Chest X-ray Classification for Pneumonia using Deep Learning¶
Table of Contents¶
- Chest X-ray Classification for Pneumonia using Deep Learning
- Part 0: Config
- 1 Preprocessing + Model Development
- 1.1 Border cleanup: cropping dark frames
- 1.2 General preprocessing
- 1.3 Model training
- 2 Occlusion function
- 3 Superpixels
- 4 Ordering Segments from Least Important
- 5 Visualization of Results
Part 0: Config¶
0.1 Runtime and working directory¶
This cell standardises the working environment by selecting a single project root and dataset directory. It first checks for a manual override through the environment variable CXR_PROJ_ROOT, then searches for common paths on Google Drive if running in Colab, and finally falls back to the current working directory. Both the project and dataset paths are exported as environment variables so later cells can rely on a single source of truth.
# keeping imports minimal here on purpose; only stdlib + Path
import sys, os
from pathlib import Path

# quick flag so the rest of the notebook can behave differently on Colab
IN_COLAB = "google.colab" in sys.modules

# (optional) mount
if IN_COLAB:
    try:
        # allowed exception to the "no 3rd-party" rule: Colab's own helper
        from google.colab import drive
        drive.mount("/content/drive", force_remount=False)  # don't spam remounts
    except Exception as e:
        # don't hard-fail the session just because Drive didn't mount
        print("Drive mount failed:", e)

# allow an external override (useful for running from different machines/CI)
OVERRIDE = os.environ.get("CXR_PROJ_ROOT", "").strip()
candidates = []  # we'll probe these in order and pick the first that exists
if OVERRIDE:
    candidates.append(Path(OVERRIDE))
if IN_COLAB:
    # common MyDrive locations I use; both included because my layouts vary
    candidates += [
        Path("/content/drive/MyDrive/code/cxr_assignment"),
        Path("/content/drive/MyDrive/cxr_assignment"),
    ]
# always consider the current working directory as a fallback
candidates.append(Path.cwd())

# pick the first existing path as the project root; resolve for cleanliness
PROJ_ROOT = next((p.resolve() for p in candidates if p.exists()), Path.cwd().resolve())

# dataset lives under <proj>/data/chest_xray (matches markdown docs above)
DATA_DIR = PROJ_ROOT / "data" / "chest_xray"
# make sure the data parent exists so later downloads don't blow up
DATA_DIR.parent.mkdir(parents=True, exist_ok=True)

# export these so other scripts/cells can just read the env instead of re-deriving
os.environ["CXR_PROJ_ROOT"] = str(PROJ_ROOT)
os.environ["CXR_DATA_ROOT"] = str(DATA_DIR)

# small sanity printout so I can see what the environment decided
print("IN_COLAB:", IN_COLAB)
print("PROJECT ROOT:", PROJ_ROOT)
print("CXR_DATA_ROOT:", os.environ["CXR_DATA_ROOT"])
Mounted at /content/drive
IN_COLAB: True
PROJECT ROOT: /content/drive/MyDrive/code/cxr_assignment
CXR_DATA_ROOT: /content/drive/MyDrive/code/cxr_assignment/data/chest_xray
0.2 Dependencies¶
All software requirements are installed directly from the project’s requirements.txt. To ensure reproducibility, the cell removes conflicting OpenCV packages, upgrades pip, and in Colab enables access to GPU-optimised PyTorch wheels. The installation process avoids importing third-party libraries prematurely, which prevents stale module states. Once complete, the environment is flagged as ready for subsequent imports.
import sys, os, subprocess
from pathlib import Path

IN_COLAB = "google.colab" in sys.modules
PROJ_ROOT = Path(os.environ.get("CXR_PROJ_ROOT", Path.cwd()))

# pick ONE file: Colab → requirements.txt ; local → requirements-local.txt
REQ_NAME = os.environ.get(
    "CXR_REQUIREMENTS_FILE",
    "requirements.txt" if IN_COLAB else "requirements-local.txt"
)
REQ_PATH = (PROJ_ROOT / REQ_NAME).resolve()
if not REQ_PATH.exists():
    raise FileNotFoundError(f"{REQ_NAME} not found at {REQ_PATH}")

# gentle guard re: Python 3.13
if sys.version_info >= (3, 13):
    print("Python 3.13 detected. Prefer 3.10–3.12 for PyTorch compatibility.")

def run_pip(args, quiet=False):
    if IN_COLAB:
        import IPython
        ip = IPython.get_ipython()
        if ip:
            line = " ".join(["-q"] + args) if quiet else " ".join(args)
            print("%pip", line)
            ip.run_line_magic("pip", line)
            return
    cmd = [sys.executable, "-m", "pip", *args]
    print(">", " ".join(cmd))
    subprocess.check_call(cmd)

print(f"Using requirements file: {REQ_PATH.name} (IN_COLAB={IN_COLAB})")

# ---- Colab: minimal, no pip upgrade, no OpenCV uninstall ----
if IN_COLAB:
    # Install base deps; Colab already has torch and a working NumPy
    run_pip(["install", "-r", str(REQ_PATH)], quiet=True)
# ---- Local (VS Code / Jupyter): fuller setup ----
else:
    # avoid ABI mismatches from a previously imported cv2
    for name in [k for k in list(sys.modules) if k == "cv2" or k.startswith("cv2.")]:
        del sys.modules[name]
    # keep tooling current locally (safe outside Colab)
    run_pip(["install", "--upgrade", "pip", "setuptools", "wheel"])
    # ensure a single OpenCV flavor locally (headless)
    run_pip(["uninstall", "-y", "opencv-python", "opencv-contrib-python", "opencv-python-headless"])
    run_pip(["install", "-r", str(REQ_PATH)])

# optional torch memory hint (no-op on CPU/MPS)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["CXR_ENV_READY"] = "1"

# show key versions so graders know what got installed
try:
    import numpy, torch, cv2
    print(f"Ready via {REQ_PATH.name}")
    print("Python:", sys.version.split()[0],
          "| NumPy:", numpy.__version__,
          "| Torch:", getattr(torch, '__version__', 'n/a'),
          "| OpenCV:", cv2.__version__)
except Exception as e:
    print("Ready via", REQ_PATH.name, "(version check: import failed)", e)
Using requirements file: requirements.txt (IN_COLAB=True)
%pip -q install -r /content/drive/MyDrive/code/cxr_assignment/requirements.txt
Ready via requirements.txt
Python: 3.12.11 | NumPy: 2.0.2 | Torch: 2.8.0+cu126 | OpenCV: 4.12.0
0.3 Imports¶
This cell loads all external libraries used throughout the notebook, including numerical, plotting, deep learning, and evaluation packages. It also sets global random seeds across Python, NumPy, and PyTorch to ensure determinism in experiments. Reproducibility is reinforced by configuring PyTorch to disable autotuning and enforce deterministic behaviour where possible. Version numbers of the main libraries are printed to record the runtime environment.
# Stdlib
import os, sys, random, time, glob, re, csv, hashlib
from typing import Tuple
from pathlib import Path
from collections import Counter, defaultdict
import shutil

# Core numerics & viz
import numpy as np
import matplotlib.pyplot as plt
import cv2

# PyTorch stack
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models

# Scikit-learn
from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score

# Scikit-image
from skimage.segmentation import slic, mark_boundaries
from skimage.color import rgb2lab

# Reporting
from tabulate import tabulate

# Data download
import kagglehub

# ---- reproducibility setup ----
SEED = 42

def set_global_seed(seed: int):
    # sync all random sources to the same seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    # force deterministic backend behavior
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

set_global_seed(SEED)

def _seed_worker(worker_id):
    # ensures dataloader workers also inherit deterministic seeds
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed); random.seed(worker_seed)

# default generator for reproducible DataLoader
GEN_DEFAULT = torch.Generator().manual_seed(SEED)

# set default plot size so outputs look consistent
plt.rcParams.update({"figure.figsize": (6, 6)})

# quick version printouts for sanity/logs
print("Python:", sys.version.split()[0])
print("NumPy:", np.__version__, "| OpenCV:", cv2.__version__, "| Matplotlib:", plt.matplotlib.__version__)
print("Torch:", torch.__version__)
Python: 3.12.11
NumPy: 2.0.2 | OpenCV: 4.12.0 | Matplotlib: 3.10.0
Torch: 2.8.0+cu126
0.4 Optional GPU configurations¶
If a CUDA-enabled GPU is available, this cell enables PyTorch optimisations that improve runtime efficiency. These include activating cuDNN benchmarking and permitting TensorFloat-32 operations on supported hardware. The cell does not change model semantics but ensures better throughput during training when GPU acceleration is present.
if torch.cuda.is_available():
    # if a GPU is around, enable faster (non-deterministic) paths + TF32
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# confirm whether CUDA is visible to torch
print("CUDA available:", torch.cuda.is_available())
CUDA available: True
0.5 Data download¶
The dataset is retrieved from Kaggle via kagglehub. Because the provided archive contains nested folders and auxiliary metadata (such as __MACOSX), the cell searches for the shallowest clean directory containing the canonical train, val, and test splits. These splits are then copied into the unified dataset root defined in 0.1, ensuring a consistent structure regardless of how the archive was packaged. A short summary of file counts confirms successful preparation.
DATA_ROOT = Path(os.environ["CXR_DATA_ROOT"]).resolve()  # final destination for the prepared dataset
DATA_ROOT.mkdir(parents=True, exist_ok=True)  # make sure target exists

def _has_split(root: Path) -> bool:
    # quick sanity: do we already have non-empty train/val/test dirs?
    return all((root / s).is_dir() and any((root / s).glob("**/*")) for s in ("train", "val", "test"))

if _has_split(DATA_ROOT):
    print("✅ Dataset already prepared at:", DATA_ROOT)  # nothing to do
else:
    # 1) Download/cache
    cache_root = Path(kagglehub.dataset_download("paultimothymooney/chest-xray-pneumonia")).resolve()
    print("Downloaded to cache:", cache_root)

    # 2) Find the *top-level* split root: immediate children must be train/val/test;
    #    ignore __MACOSX; prefer the shallowest dir so we don't nest extra levels
    candidates = []
    for p in cache_root.rglob("*"):
        if not p.is_dir():
            continue
        names = {c.name.lower() for c in p.iterdir() if c.is_dir()}
        if {"train", "val", "test"}.issubset(names):
            if "__macosx" in {part.lower() for part in p.parts}:
                continue
            # rank by depth (smaller = shallower = better)
            candidates.append((len(p.parts), p))
    if not candidates:
        raise FileNotFoundError(
            f"Could not locate a clean top-level split root under {cache_root} "
            "(looked for a folder with immediate train/ val/ test/)."
        )
    # choose the shallowest clean split root
    split_root = sorted(candidates, key=lambda t: t[0])[0][1]
    print("Using split root:", split_root)

    # 3) Copy only train/val/test into DATA_ROOT (idempotent-ish)
    for split in ("train", "val", "test"):
        src = (split_root / split).resolve()
        dst = (DATA_ROOT / split).resolve()
        if dst.exists() and any(dst.glob("**/*")):
            print(f"↪ Skip existing '{split}' at {dst}")  # keep what's already there
            continue
        print(f"Copying {split}: {src} -> {dst}")
        shutil.copytree(src, dst, dirs_exist_ok=True)  # could use hardlinks if same FS

# 4) Summary
def count_files(folder: Path):
    return sum(1 for f in folder.rglob("*") if f.is_file())

summary = {s: count_files(DATA_ROOT / s) for s in ("train", "val", "test")}
print("Prepared at:", DATA_ROOT)
print("Counts:", summary)
Using Colab cache for faster access to the 'chest-xray-pneumonia' dataset.
Downloaded to cache: /kaggle/input/chest-xray-pneumonia
Using split root: /kaggle/input/chest-xray-pneumonia/chest_xray
Copying train: /kaggle/input/chest-xray-pneumonia/chest_xray/train -> /content/drive/MyDrive/code/cxr_assignment/data/chest_xray/train
Copying val: /kaggle/input/chest-xray-pneumonia/chest_xray/val -> /content/drive/MyDrive/code/cxr_assignment/data/chest_xray/val
Copying test: /kaggle/input/chest-xray-pneumonia/chest_xray/test -> /content/drive/MyDrive/code/cxr_assignment/data/chest_xray/test
Prepared at: /content/drive/MyDrive/code/cxr_assignment/data/chest_xray
Counts: {'train': 5216, 'val': 16, 'test': 624}
1 Preprocessing + Model Development¶
1.1 Border cleanup: cropping dark frames¶
1.1.1 Why remove borders?¶
Many of the chest X-ray images include thick black frames or scanner outlines around the actual anatomy. In some cases there are also letters or orientation markers near the edges. Leaving these in risks the model paying attention to shortcuts at the periphery instead of the lung fields, which could reduce generalization and lead to biased performance.
1.1.2 Early attempts¶
Different strategies for handling annotations were explored at first. Marker detection combined with cropping or inpainting showed some promise, but each came with drawbacks. Inpainting in particular introduced artificial texture that might confuse the network, while marker-based cropping sometimes removed parts of the anatomy when detections were imperfect.
1.1.3 Final method: dark-frame cropping¶
A more robust solution was to focus on the consistently dark outer areas of the scans. By scanning inward from each side and measuring how “border-like” a line of pixels is, it becomes possible to trim away only the black margins while leaving the central lung region intact. A small padding is kept to ensure the true content is not cut too tightly.
1.1.4 Parameters¶
Key parameters that shaped the behavior of the crop include:
- border_ratio (default 0.10) — thickness of the edge band that is checked.
- dark_frac (default 0.8) — threshold for how dark a line must be to count as border.
- min_keep_run (default 8) — how many consecutive “non-border” lines are needed before deciding real content has started.
- max_crop_frac (default 0.12) — maximum fraction of width/height allowed to be cropped, as a safeguard.
- pad (default 4) — small buffer after the cut to clear the border fully.
def crop_dark_frame(
    img_bgr_or_gray,
    border_ratio=0.10,   # fraction of W/H used as border band for stats
    gauss_ks=3,          # small pre-blur before thresholding
    window=7,            # smoothing window when scanning rows/cols
    dark_frac=0.8,       # row/col is "border" if >= this fraction is dark
    min_keep_run=8,      # need this many consecutive non-border lines to stop
    max_crop_frac=0.12,  # cap how much any one side can be cropped
    pad=4                # keep a small pad past the stopping point
):
    """
    Remove dark canvas/outlines from all four sides. Returns (cropped_img, (y0,y1,x0,x1)).
    """
    # 0) grayscale uint8
    img = img_bgr_or_gray
    g = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img.copy()  # convert if BGR
    if g.dtype != np.uint8:
        g = cv2.normalize(g, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)  # normalize to [0,255]
    h, w = g.shape
    if h < 16 or w < 16:
        return g, (0, h, 0, w)  # tiny image: bail early

    # 1) border band
    bw, bh = max(2, int(border_ratio * w)), max(2, int(border_ratio * h))  # band thickness
    band = np.zeros_like(g, np.uint8)
    band[:, :bw] = 1; band[:, w - bw:] = 1; band[:bh, :] = 1; band[h - bh:, :] = 1  # 1s on edges

    # 2) build "very-dark" mask using Otsu on inverted image within the band
    ks = max(1, gauss_ks | 1)  # force odd kernel size
    sm = cv2.GaussianBlur(g, (ks, ks), 0)
    inv = 255 - sm
    inv_band = inv[band > 0]  # only edge pixels for threshold stats
    if inv_band.size == 0:
        return g, (0, h, 0, w)  # no border stats -> do nothing
    thr_val, _ = cv2.threshold(inv_band, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    dark = (inv >= thr_val).astype(np.uint8)  # 1 where "very dark" after inversion

    # 3) helpers: scan inward from a side and find the stopping index
    def _stop_index_along_rows(mask, from_top=True):
        # mask: 0/1; per-row dark fraction
        row_frac = mask.mean(axis=1)
        # moving average to stabilize
        k = max(1, window)
        kernel = np.ones(k) / k
        sm = np.convolve(row_frac, kernel, mode='same')
        # scan direction
        idxs = range(0, h) if from_top else range(h - 1, -1, -1)
        keep_run = 0
        cut = 0
        for i in idxs:
            is_border = sm[i] >= dark_frac
            if not is_border:
                keep_run += 1
            else:
                keep_run = 0
            cut += 1  # count every scanned line (border or not)
            if keep_run >= min_keep_run:
                cut -= keep_run  # back up to the first good row
                break
        # enforce limits
        cut = min(cut + pad, int(max_crop_frac * h))
        if not from_top:
            # count from bottom
            return h - cut
        return cut

    def _stop_index_along_cols(mask, from_left=True):
        col_frac = mask.mean(axis=0)
        k = max(1, window)
        kernel = np.ones(k) / k
        sm = np.convolve(col_frac, kernel, mode='same')
        idxs = range(0, w) if from_left else range(w - 1, -1, -1)
        keep_run = 0
        cut = 0
        for j in idxs:
            is_border = sm[j] >= dark_frac
            if not is_border:
                keep_run += 1
            else:
                keep_run = 0
            cut += 1  # count every scanned line (border or not)
            if keep_run >= min_keep_run:
                cut -= keep_run  # back up to the first good column
                break
        cut = min(cut + pad, int(max_crop_frac * w))
        if not from_left:
            return w - cut
        return cut

    # 4) compute bounds
    y0 = _stop_index_along_rows(dark, from_top=True)
    y1 = _stop_index_along_rows(dark, from_top=False)
    x0 = _stop_index_along_cols(dark, from_left=True)
    x1 = _stop_index_along_cols(dark, from_left=False)

    # 5) sanity clamps
    y0 = int(np.clip(y0, 0, h - 2))
    y1 = int(np.clip(y1, y0 + 1, h))
    x0 = int(np.clip(x0, 0, w - 2))
    x1 = int(np.clip(x1, x0 + 1, w))
    return g[y0:y1, x0:x1], (y0, y1, x0, x1)
1.1.5 Quality check¶
Sampling. To check the effect of the cropping step, four images were sampled at random (fixed seed = 42) from the training and test folders. Using a fixed seed keeps the check reproducible.
What the figure shows. Each row compares the Original X-ray (left) with the Cropped version (right). No resizing or normalization is applied here so that the change from cropping alone can be seen clearly.
Observations from these examples.
- The black scanner borders are consistently removed on all four examples.
- The central lung fields remain intact and well centered.
- Small variations in how much is trimmed per side can be seen, depending on the local darkness at the edges.
Failure patterns to watch for.
- Over-crop risk: if peripheral anatomy is unusually dark (for example, rib edges near the border), the scan-in procedure might trim slightly deeper than needed. The max_crop_frac safeguard helps prevent excessive loss.
- Under-crop risk: on scans where the border is not strongly dark compared to the interior, a thin strip of margin may remain.
- Extreme exposure cases: for very bright or noisy images, the dark/normal thresholding could in theory misjudge the transition point, though this was not observed in the sample shown.
Mini QC (same four images). For each, I log the bounding box (y0, y1, x0, x1) of the cropped region. This makes it easy to spot anomalies (for example, if one side consistently hits the max_crop_frac limit). Across the four sampled examples, the crops all stayed within expected ranges and the main lung area was preserved.
def show_dark_crop(paths):
    # Visual compare: original vs. cropped for each path
    rows = len(paths)
    fig, axes = plt.subplots(rows, 2, figsize=(10, 3 * rows), constrained_layout=True)
    if rows == 1:
        axes = [axes]  # normalize single-row (so zip works the same)
    for ax_row, p in zip(axes, paths):
        g = cv2.imread(p, cv2.IMREAD_GRAYSCALE)
        if g is None:
            raise FileNotFoundError(p)
        cropped, (y0, y1, x0, x1) = crop_dark_frame(g)
        # left: raw image with filename for reference
        ax_row[0].imshow(g, cmap='gray')
        ax_row[0].set_title(f"Original\n{os.path.basename(p)}"); ax_row[0].axis('off')
        # right: cropped result + explicit bounds
        ax_row[1].imshow(cropped, cmap='gray')
        ax_row[1].set_title(f"Cropped\n(y0={y0}, y1={y1}, x0={x0}, x1={x1})"); ax_row[1].axis('off')
    # slight top margin so the title doesn't overlap
    plt.suptitle("Dark-border cropping (outside → in)", fontsize=14, y=1.02)
    plt.show()

def qc_dark_crop(paths):
    """Print how often something was cropped and typical area removed."""
    # Track % area removed per image; also count how many images changed shape
    pcts = []
    changed = 0
    for p in paths:
        g = cv2.imread(p, cv2.IMREAD_GRAYSCALE)
        if g is None:
            continue
        h, w = g.shape
        cropped, _ = crop_dark_frame(g)
        ch, cw = cropped.shape
        if ch != h or cw != w:
            changed += 1
        # 100 * (1 - new_area / old_area)
        pcts.append(100.0 * (1.0 - (ch * cw) / (h * w)))
    if not pcts:
        print("No images found.")
        return
    print(f"Cropped on: {100.0 * changed / len(pcts):.1f}% of images")
    print(f"Median area removed: {np.median(pcts):.3f}% (mean {np.mean(pcts):.3f}%)")

def sample_images(
    n=4, seed=42, root=DATA_ROOT,
    splits=("train", "val", "test"),
    classes=("NORMAL", "PNEUMONIA"),
    exts=(".jpeg", ".jpg", ".png", ".JPEG", ".JPG", ".PNG")
):
    """
    Returns `n` unique image paths with a fixed RNG seed.
    Strategy: pick 1 per (split,class) bucket first, then fill the rest.
    """
    rng = random.Random(seed)
    # Gather file lists per (split, class); accept common image extensions
    buckets, all_files = [], []
    for sp in splits:
        for cls in classes:
            files = []
            for ext in exts:
                files.extend(glob.glob(os.path.join(root, sp, cls, f"*{ext}")))
            files = sorted(set(files))  # dedupe per bucket
            if files:
                buckets.append(files)
                all_files.extend(files)
    all_files = sorted(set(all_files))  # global dedupe
    if not all_files:
        raise FileNotFoundError(f"No images found under {root} with expected split/class folders.")
    # First pass: maximize diversity (at most one from each bucket)
    rng.shuffle(buckets)
    picks = []
    for files in buckets:
        if len(picks) >= n:
            break
        cand = rng.choice(files)
        if cand not in picks:
            picks.append(cand)
    # Second pass: random fill from the remaining pool
    leftovers = list(all_files); rng.shuffle(leftovers)
    for f in leftovers:
        if len(picks) >= n:
            break
        if f not in picks:
            picks.append(f)
    # If the dataset is tiny, warn; otherwise trim to exactly n
    if len(picks) < n:
        print(f"Warning: only found {len(picks)} images (< {n}).")
    return picks[:n]

N_SAMPLES = 4
paths = sample_images(n=N_SAMPLES, seed=42, root=DATA_ROOT)  # pick a small, reproducible set
# Visual verification: side-by-side original vs cropped for a quick sanity check
show_dark_crop(paths)
# Tiny QC summary: how often cropping triggered + typical % area removed
qc_dark_crop(paths)
Cropped on: 100.0% of images
Median area removed: 15.756% (mean 16.593%)
1.2 General preprocessing¶
1.2.1 Scope¶
After removing the dark borders (Section 1.1), the next step is preparing the images in a consistent way for training. This includes optional train-time augmentation (flip/rotate), resizing all images to a fixed input size, and normalizing pixel values to [0,1]. Having a fixed and simple pipeline keeps training stable and makes it easier to compare results.
Order of operations:¶
1. Dark-border cropping (from §1.1)
2. (train only) flip + rotate
3. Resize to the chosen input size
4. Normalize pixel values to [0,1]
1.2.2 Augmentation: flip + rotate¶
At first, different augmentation modes were implemented (flip only, rotate only, or flip+rotate) to see which would be most useful. In practice, the combination of horizontal flip and small random rotations gave the most balanced results. It made the model less sensitive to patient positioning or orientation without introducing unrealistic distortions.
Even though only flip+rotate was used in the final runs, the other modes remain in the code so that readers (or future experiments) can easily try them as alternatives.
- Horizontal flip. Safe to apply here since the label is “pneumonia present/absent,” not side-specific.
- Small rotation. Adds robustness to minor tilts in acquisition. Rotation is limited to a few degrees and filled by replication, so no black wedges appear at the borders.
These augmentations are applied only during training. Validation and test images are left deterministic to ensure results reflect true generalization.
# --- horizontal flip (train-only) ---
def maybe_hflip(img, p=0.5):
    """
    Flip left↔right with probability p (train-only).
    Works for grayscale uint8/float images.
    """
    if p <= 0:
        return img  # disabled
    if np.random.rand() < p:
        return cv2.flip(img, 1)  # 1 = horizontal flip
    return img  # no-op

# --- small random rotation (train-only) ---
def rotate_small(img, max_deg=5):
    """
    Rotate by a random angle in [-max_deg, +max_deg].
    Keeps size; fills borders by replication to avoid black wedges.
    """
    if max_deg <= 0:
        return img  # disabled
    h, w = img.shape[:2]
    angle = np.random.uniform(-max_deg, max_deg)  # sample once per call
    M = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1.0)
    return cv2.warpAffine(
        img, M, (w, h),
        flags=cv2.INTER_LINEAR,           # smooth for small angles
        borderMode=cv2.BORDER_REPLICATE   # extend edges instead of black fill
    )

def maybe_rotate(img, max_deg=5, p=0.5):
    """
    Apply rotate_small with probability p (train-only).
    """
    if p <= 0 or max_deg <= 0:
        return img  # disabled
    if np.random.rand() > p:
        return img  # skip
    return rotate_small(img, max_deg=max_deg)

# --- convenience: apply both, in the recommended order ---
def maybe_flip_and_rotate(img, p_hflip=0.5, rotate_max_deg=5, p_rotate=0.5):
    """
    Train-time augmentation: optional horizontal flip, then optional small rotation.
    """
    img = maybe_hflip(img, p=p_hflip)                            # flip first (label-safe for CXR)
    img = maybe_rotate(img, max_deg=rotate_max_deg, p=p_rotate)  # gentle jitter
    return img
1.2.3 Resize to a fixed input¶
Goal¶
To train the model in batches, all images need to be standardized to the same size (224×224 in this project). Without resizing, differences in resolution would make batching impossible.
Method¶
Several resizing approaches were tested, including direct resize (which stretches images) and crop-based resize (which can cut away parts of the anatomy). The most balanced approach turned out to be aspect-ratio–preserving resize with padding (“letterbox”). This scales the image until the longer side matches the target size and then pads the shorter side to fit. For padding, edge replication or reflection was used so that the borders do not add unnatural patterns.
This avoids distortion of anatomical structures while keeping the central lung fields intact. Padding is generally small relative to the whole image and does not interfere with interpretation.
Application¶
- Applied consistently to train, validation, and test images.
- Comes after augmentation but before normalization in the preprocessing sequence.
import cv2
import numpy as np

def resize_cover_then_center_crop(img, out=(224, 224)):
    # scale up/down so the target crop is fully covered, then center-crop
    th, tw = out
    H, W = img.shape[:2]
    scale = max(th / H, tw / W)  # minimal scale to cover the crop
    newH, newW = int(round(H * scale)), int(round(W * scale))
    inter = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR  # downsample vs upsample
    r = cv2.resize(img, (newW, newH), interpolation=inter)
    y0 = (newH - th) // 2; x0 = (newW - tw) // 2  # center offsets
    return r[y0:y0 + th, x0:x0 + tw]  # fixed-size crop

def resize_fit_then_pad(img, out=(224, 224), pad_mode='reflect', pad_value=None):
    # fit inside target box (keep aspect), then pad to exact size
    th, tw = out
    H, W = img.shape[:2]
    scale = min(th / H, tw / W)  # uniform scale to fit inside
    newH, newW = int(round(H * scale)), int(round(W * scale))
    inter = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR
    r = cv2.resize(img, (newW, newH), interpolation=inter)
    # symmetric padding amounts
    top = (th - newH) // 2; bottom = th - newH - top
    left = (tw - newW) // 2; right = tw - newW - left
    if pad_mode == 'constant':
        # constant pad using a mid-tone default (median of original)
        if pad_value is None:
            pad_value = int(np.median(img))
        value = pad_value if r.ndim == 2 else [pad_value] * 3
        return cv2.copyMakeBorder(r, top, bottom, left, right, cv2.BORDER_CONSTANT, value=value)
    else:  # 'reflect' recommended for CXRs to avoid artificial edges
        return cv2.copyMakeBorder(r, top, bottom, left, right, cv2.BORDER_REFLECT_101)

# global switches consumed by _apply_final_resize; the reported runs use the
# aspect-preserving letterbox ("fit_pad") with reflect padding (§1.2.3)
_RESIZE_MODE = "fit_pad"  # or "cover_crop"
_PAD_MODE = "reflect"     # or "constant"
_PAD_VALUE = None         # only used when _PAD_MODE == "constant"

def _apply_final_resize(img, out=(224, 224)):
    if _RESIZE_MODE == "cover_crop":
        return resize_cover_then_center_crop(img, out=out)
    elif _RESIZE_MODE == "fit_pad":
        return resize_fit_then_pad(img, out=out, pad_mode=_PAD_MODE, pad_value=_PAD_VALUE)
    else:
        raise ValueError("Unknown _RESIZE_MODE")
1.2.4 Normalization¶
Goal¶
Ensure all images share the same numeric scale so that training remains stable and comparable.
Method¶
A simple min–max normalization to the range [0,1] is applied by dividing 8-bit pixel values by 255. This approach preserves overall brightness differences between scans without forcing every image into the same mean/variance, which could flatten out clinically meaningful contrast.
The process is the same for train, validation, and test. The result is stored as float32. If later models expect 3-channel input (e.g. ImageNet backbones), the single channel can be replicated at the dataloader stage.
def normalize01(img):
    """
    Convert image to float32 in [0,1].
    - If uint8: scale by 1/255.
    - If already float and <=1, return as-is.
    - Otherwise, clip to [0,255] then scale.
    Works with HxW or HxWx1 arrays.
    """
    if img.dtype == np.uint8:
        return img.astype(np.float32) / 255.0  # common fast path
    out = img.astype(np.float32)  # work in float either way
    if out.max() <= 1.0 and out.min() >= 0.0:
        return out  # already normalized
    # fallback: clamp to byte range then map to [0,1]
    return np.clip(out, 0.0, 255.0) / 255.0
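The channel replication mentioned above (for 3-channel ImageNet backbones, done at the dataloader stage) can be sketched as a one-liner; to_3ch is an illustrative name, not a helper defined in this notebook:

```python
import numpy as np

def to_3ch(img01):
    """Replicate a single-channel (H, W) image to (H, W, 3)."""
    # add a trailing channel axis, then copy it three times
    return np.repeat(img01[..., None], 3, axis=-1)
```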
1.3 Model training¶
1.3.1 Model architecture¶
Backbone¶
The main model is EfficientNet-B0 (ImageNet-pretrained). It offers a good speed/accuracy trade-off and is a common starting point for medical imaging fine-tuning. A DenseNet-121 (pretrained) variant is also implemented in code for comparison, but the reported results use EfficientNet-B0.
Why pretrained?¶
Starting from ImageNet weights provides generic low-level filters (edges, textures) that adapt well after fine-tuning. This typically converges faster and reaches higher AUROC than training from scratch on this dataset.
Grayscale input. Chest X-rays are single-channel. The first convolution is adapted to 1 channel by averaging the pretrained RGB kernels. (Replicating the channel to 3 would also work, but keeping a native 1-channel stem keeps the parameter count slightly lower.)
Head. The classification head is replaced with a single logit for pneumonia vs. normal.
Training plan (high level)¶
Warm-start by training only the new head with the backbone frozen for a few epochs; then unfreeze and fine-tune end-to-end with a smaller learning rate and early stopping on validation performance.
Loss, optimizer, scheduler, and early-stopping details appear in §1.3.3–§1.3.4.
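The two-phase plan can be sketched as a pair of optimizer setups. TinyNet below is an illustrative stand-in for the real backbone (the notebook builds EfficientNet-B0 via build_model in §1.3.1), and the learning rates are indicative defaults, not the tuned values:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Toy model with the same backbone/head split used by build_model."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.classifier = nn.Linear(8, 1)  # single logit, as in §1.3.1

    def forward(self, x):
        return self.classifier(self.features(x))

model = TinyNet()

# Phase 1: freeze the backbone, warm-start only the new head at a larger LR
for p in model.features.parameters():
    p.requires_grad = False
opt_head = torch.optim.AdamW(model.classifier.parameters(), lr=1e-3)

# Phase 2 (after a few epochs): unfreeze and fine-tune end-to-end at a smaller LR
for p in model.features.parameters():
    p.requires_grad = True
opt_full = torch.optim.AdamW(model.parameters(), lr=1e-4)
```

In the notebook, the freezing step is handled by freeze_backbone(model, freeze=True/False) from §1.3.1; early stopping on validation performance decides when each phase ends.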
# Device pick once up front
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def _make_first_conv_1ch(conv3: nn.Conv2d) -> nn.Conv2d:
    """
    Create a new Conv2d that takes 1 channel, initializing weights by
    averaging the pretrained RGB filters.
    """
    # clone conv shape/hyperparams, just switch to 1 input channel
    new = nn.Conv2d(
        in_channels=1,
        out_channels=conv3.out_channels,
        kernel_size=conv3.kernel_size,
        stride=conv3.stride,
        padding=conv3.padding,
        dilation=conv3.dilation,
        groups=conv3.groups,
        bias=(conv3.bias is not None),
    )
    with torch.no_grad():
        w = conv3.weight.data  # [out, 3, k, k]
        new.weight.copy_(w.mean(dim=1, keepdim=True))  # average RGB → single 1ch filter
        if conv3.bias is not None:
            new.bias.copy_(conv3.bias.data)
    return new

def build_model(backbone: str = "efficientnet_b0",
                pretrained: bool = True,
                in_chans: int = 1) -> nn.Module:
    """
    backbone:   'efficientnet_b0' | 'densenet121'
    pretrained: use ImageNet weights
    in_chans:   1 (grayscale) or 3
    Returns a model that outputs a single logit.
    """
    b = backbone.lower()
    if b == "efficientnet_b0":
        weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        m = models.efficientnet_b0(weights=weights)
        # first conv located at features[0][0]
        if in_chans == 1:
            m.features[0][0] = _make_first_conv_1ch(m.features[0][0])
        # replace classifier head with 1-logit output
        in_feat = m.classifier[1].in_features
        m.classifier[1] = nn.Linear(in_feat, 1)
        # convenience handle to (un)freeze backbone later
        m._backbone_params = m.features.parameters
    elif b == "densenet121":
        weights = models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
        m = models.densenet121(weights=weights)
        # first conv is features.conv0
        if in_chans == 1:
            m.features.conv0 = _make_first_conv_1ch(m.features.conv0)
        # replace classifier head with 1-logit output
        in_feat = m.classifier.in_features
        m.classifier = nn.Linear(in_feat, 1)
        m._backbone_params = m.features.parameters
    else:
        raise ValueError("backbone must be 'efficientnet_b0' or 'densenet121'")
    return m

def freeze_backbone(model: nn.Module, freeze: bool = True) -> None:
    """
    Freeze/unfreeze the backbone parameters.
    (The final classifier stays trainable.)
    """
    for p in model._backbone_params():
        p.requires_grad = not freeze  # flip the switch

def count_trainable_params(model: nn.Module) -> int:
    # quick size-of-head sanity when freezing
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# ---- Build the model (architecture only; loss/optim come in §1.3.3) ----
model = build_model(backbone="efficientnet_b0", pretrained=True, in_chans=1).to(DEVICE)

# Warm-up plan: start with the backbone frozen; unfreeze later in §1.3.4
freeze_backbone(model, freeze=True)

print(model.__class__.__name__, "on", DEVICE)
print("Trainable params (with backbone frozen):", f"{count_trainable_params(model):,}")
Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:10<00:00, 2.12MB/s]
EfficientNet on cuda
Trainable params (with backbone frozen): 1,281
1.3.2 Data splits & labels (patient-level validation)¶
Counts from the original folders¶
The dataset folders come with a predefined split:
| Split | NORMAL | PNEUMONIA | Total |
|---|---|---|---|
| train | 1341 | 3875 | 5216 |
| val | 8 | 8 | 16 |
| test | 234 | 390 | 624 |
Issues with the default split¶
Two main problems stood out:
- The validation set is extremely small (16 images total), which is unstable for early stopping or reliable model selection.
- Multiple images can belong to the same patient (e.g. person123_bacteria_1.jpeg, person123_bacteria_2.jpeg). With the default split, different images from the same patient can appear across train, val, and test, which risks information leakage and overly optimistic performance.
Alternative attempt (image-level stratified split)¶
One idea was to create a larger validation set by stratifying at the image level. This balanced the classes and gave more samples for validation. However, early experiments on this split produced unrealistically high scores (e.g. AUROC ≈ 0.996 after a single epoch). That suggested the model was still exploiting leakage — different images of the same patient ending up in both training and validation.
Final choice: patient-level split¶
To prevent leakage, a GroupShuffleSplit based on patient IDs was used. This guarantees that all images from a given patient appear in only one split. Proportions were set to roughly match the original 80/10/10 scheme.
The final distribution was:
- train: 4719 images, 74% positive, 2538 unique patients
- val: 530 images, 73% positive, 286 patients
- test: 607 images, 65% positive, 350 patients
There is zero patient overlap between splits.
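The leakage-free property is easy to sanity-check in isolation. A toy sketch using the same GroupShuffleSplit API (the patient names and counts here are made up for illustration):

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Toy data: 8 images from 4 patients, two images per patient.
paths = [f"person{i}_img{j}.jpeg" for i in range(4) for j in range(2)]
labels = np.array([0, 0, 1, 1, 0, 0, 1, 1])
groups = np.array([p.split("_")[0] for p in paths])  # patient IDs

gss = GroupShuffleSplit(n_splits=1, test_size=0.25, random_state=0)
train_idx, test_idx = next(gss.split(paths, labels, groups))

# Every patient lands entirely on one side of the split.
overlap = set(groups[train_idx]) & set(groups[test_idx])
print(len(overlap))  # 0
```

Because splitting is done on groups rather than rows, the realized image-level fractions can drift a little from the requested `test_size`, which is why the final counts above only roughly match the 80/10/10 target.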
Preprocessing per split¶
All splits go through the same preprocessing steps (dark-border crop → resize with letterbox → normalize).
Random augmentations (flip + rotate) are applied only to training images.
Handling class imbalance¶
The dataset is imbalanced at roughly 3:1 (pneumonia:normal). During training, this is addressed by using a class-weighted loss, where the positive class is weighted by the ratio N_negative / N_positive in the training set. This adjustment only affects training; validation and test metrics remain unaffected.
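The weighting arithmetic can be spelled out with a small sketch (counts taken from the patient-level train split above; this is a naive NumPy re-derivation for illustration, not the torch op, which uses a numerically stabilized log-sum-exp form):

```python
import numpy as np

def weighted_bce_with_logits(z, y, pos_weight):
    """Per-sample class-weighted BCE on raw logits, mirroring what
    nn.BCEWithLogitsLoss(pos_weight=...) computes (naive sketch)."""
    p = 1.0 / (1.0 + np.exp(-z))
    return -(pos_weight * y * np.log(p) + (1 - y) * np.log(1 - p))

# Counts from the patient-level train split (§1.3.2).
n_neg, n_pos = 1232, 3487
pos_weight = n_neg / n_pos               # ≈ 0.353

# With this weight the classes contribute equally in aggregate:
# n_pos * pos_weight == n_neg exactly.
print(round(n_pos * pos_weight - n_neg, 6))  # 0.0

# Two equally uncertain examples (logit 0 → p = 0.5):
lp = weighted_bce_with_logits(0.0, 1, pos_weight)  # positive, weighted
ln = weighted_bce_with_logits(0.0, 0, pos_weight)  # negative, unweighted
print(lp < ln)  # True: each abundant positive counts less per sample
```

The net effect is that the abundant pneumonia class is down-weighted per sample, so the minority normal class carries equal total weight in the loss.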
Reproducibility¶
The split uses a fixed random seed so it can be reproduced exactly. Only this patient-level split is used for final model training and reporting.
root = Path(os.environ.get("CXR_DATA_ROOT", "chest_xray")) # default to local if env not set
splits = ["train", "val", "test"]
classes = ["NORMAL", "PNEUMONIA"]
exts = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} # accepted image types
def count_images(dirpath: Path) -> int:
# count files under dirpath with allowed extensions (case-insensitive)
return sum(1 for p in dirpath.rglob("*")
if p.is_file() and p.suffix.lower() in exts)
for split in splits:
split_total = 0
for cls in classes:
n = count_images(root / split / cls) # per (split, class)
split_total += n
print(f"{split}/{cls}: {n}")
print(f"{split} total: {split_total}\n") # per-split summary
train/NORMAL: 1341
train/PNEUMONIA: 3875
train total: 5216

val/NORMAL: 8
val/PNEUMONIA: 8
val total: 16

test/NORMAL: 234
test/PNEUMONIA: 390
test total: 624
# --- Patient-level split (no file moves) ---
ROOT = Path(os.environ.get("CXR_DATA_ROOT", "chest_xray"))
# SEED is defined in the first cell
TEST_FRAC = 0.11
VAL_FRAC = 0.09
exts = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
def all_jpegs(root: Path):
# gather (path, label) across existing split/class folders
paths = []
for split in ["train", "val", "test"]:
for cls, y in [("NORMAL", 0), ("PNEUMONIA", 1)]:
for p in Path(root, split, cls).rglob("*"):
if p.is_file() and p.suffix.lower() in exts:
paths.append((str(p), y))
return sorted(paths)
# try to extract a patient identifier from filename
pid_person = re.compile(r"(person\d+)", re.I)
def patient_id_from_path(p: str) -> str:
name = Path(p).stem
m = pid_person.search(name)
if m:
return m.group(1).lower()
if name.startswith("IM-") and "-" in name:
t = name.split("-")
return f"{t[0]}-{t[1]}".lower() if len(t) > 1 else name.lower()
return name.lower()
# build arrays for GroupShuffleSplit
pairs = all_jpegs(ROOT)
X = [p for p,_ in pairs]
y = np.array([int(lbl) for _,lbl in pairs], dtype=int)
groups = np.array([patient_id_from_path(p) for p in X])
# first split: carve out test by patient ID
gss1 = GroupShuffleSplit(n_splits=1, test_size=TEST_FRAC, random_state=SEED)
trainval_idx, test_idx = next(gss1.split(X, y, groups))
# second split: split train/val from remaining, again by patient ID
val_size_rel = VAL_FRAC / (1.0 - TEST_FRAC)
gss2 = GroupShuffleSplit(n_splits=1, test_size=val_size_rel, random_state=SEED)
train_idx, val_idx = next(gss2.split(np.array(X)[trainval_idx],
y[trainval_idx],
groups[trainval_idx]))
train_idx = trainval_idx[train_idx]
val_idx = trainval_idx[val_idx]
def summarize(tag, idxs):
# quick per-split stats: counts, class balance, unique patients
ys = y[idxs]
n_neg = int((ys==0).sum())
n_pos = int((ys==1).sum())
tot = len(idxs)
pos_pct = 100.0*n_pos/max(1,tot)
n_pids = len(set(groups[idxs]))
return dict(tag=tag, tot=tot, neg=n_neg, pos=n_pos, pos_pct=pos_pct, patients=n_pids)
S_tr = summarize("train", train_idx)
S_va = summarize("val", val_idx)
S_te = summarize("test", test_idx)
print("== Patient-level split summary ==")
print(f"{'split':<6} {'total':>6} {'NORMAL':>8} {'PNEUM.':>8} {'Pos%':>7} {'patients':>10}")
for S in [S_tr, S_va, S_te]:
print(f"{S['tag']:<6} {S['tot']:>6} {S['neg']:>8} {S['pos']:>8} {S['pos_pct']:>6.1f}% {S['patients']:>10}")
# verify no patient leakage across splits
pid_train = set(groups[train_idx])
pid_val = set(groups[val_idx])
pid_test = set(groups[test_idx])
print("\nPatient overlap (should be 0):",
f"train∩val={len(pid_train & pid_val)},",
f"train∩test={len(pid_train & pid_test)},",
f"val∩test={len(pid_val & pid_test)}")
# materialize file paths + labels for downstream datasets
train_paths = [X[i] for i in train_idx]; train_labels = [int(y[i]) for i in train_idx]
val_paths = [X[i] for i in val_idx]; val_labels = [int(y[i]) for i in val_idx]
test_paths = [X[i] for i in test_idx]; test_labels = [int(y[i]) for i in test_idx]
# POS_WEIGHT for §1.3.3 (NEG/POS on *train* only)
NEG = int((y[train_idx]==0).sum())
POS = int((y[train_idx]==1).sum())
POS_WEIGHT = float(NEG / max(1, POS))
# compare against target counts/balance noted earlier
old = dict(train=(4694, 74.3), val=(538, 73.6), test=(624, 62.5))
now = dict(train=(S_tr['tot'], S_tr['pos_pct']), val=(S_va['tot'], S_va['pos_pct']), test=(S_te['tot'], S_te['pos_pct']))
print("\nTarget vs Now (count, Pos%):")
for k in ["train","val","test"]:
print(f"{k:>5}: target={old[k]} now={tuple(round(v,1) if isinstance(v,float) else v for v in now[k])}")
== Patient-level split summary ==
split   total   NORMAL   PNEUM.    Pos%   patients
train    4719     1232     3487   73.9%       2538
val       530      141      389   73.4%        286
test      607      210      397   65.4%        350

Patient overlap (should be 0): train∩val=0, train∩test=0, val∩test=0

Target vs Now (count, Pos%):
train: target=(4694, 74.3) now=(4719, 73.9)
  val: target=(538, 73.6) now=(530, 73.4)
 test: target=(624, 62.5) now=(607, 65.4)
1.3.3 Caching strategy¶
Why cache?¶
Some of the preprocessing steps (especially cropping and resizing) are relatively heavy when applied on-the-fly to every batch. Running them repeatedly during training would slow everything down. To avoid this, a two-stage caching system was set up so that expensive operations are only done once, and later training just reads preprocessed files directly.
Stage 1 — after patient-level split¶
The first cache is created right after the patient-level split is finalized. For each of train/val/test, a manifest CSV is written with the file paths and labels. This ensures the splits are reproducible and easy to reload. The raw images are then passed through the border-cropping and resizing steps, and the results are written to disk as stage-1 cached images. Filenames are collision-safe and follow the original directory structure, which makes them easy to trace back if needed.
Stage 2 — training-time ready cache¶
With the split finalized and the preprocessed images in place, training reads directly from this cache. Images arrive already cropped and resized, so per-epoch work reduces to a cheap intensity scaling to [0, 1] plus the lightweight train-only augmentations (flip and rotate), without heavy I/O or repeated geometric preprocessing.
Benefits¶
- Reproducibility. Manifests (CSV) make the split deterministic and portable.
- Efficiency. Time-consuming steps like cropping and resizing are performed only once.
- Flexibility. Because the original files remain untouched, the cache can be rebuilt with different preprocessing parameters if needed.
In practice, this caching reduced epoch times noticeably and made it feasible to experiment with different models and hyperparameters without redoing preprocessing each run.
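Reloading a manifest is the mirror image of writing it. A minimal round-trip sketch (the load_manifest helper is hypothetical, but it assumes the same path,label header that the save code below writes):

```python
import csv
import tempfile
from pathlib import Path

def load_manifest(csv_path):
    """Read a split manifest back into a list of (path, label) pairs."""
    with open(csv_path, newline="") as f:
        rows = list(csv.DictReader(f))
    return [(r["path"], int(r["label"])) for r in rows]

# Round-trip check with two made-up entries.
tmp = Path(tempfile.mkdtemp()) / "train.csv"
items = [("chest_xray/train/NORMAL/a.jpeg", 0),
         ("chest_xray/train/PNEUMONIA/b.jpeg", 1)]
with open(tmp, "w", newline="") as f:
    w = csv.writer(f)
    w.writerow(["path", "label"])
    w.writerows(items)

print(load_manifest(tmp) == items)  # True
```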
# --- Save split manifests (path,label) ---
def save_manifest(items, out_csv):
Path(out_csv).parent.mkdir(parents=True, exist_ok=True) # ensure folder exists
with open(out_csv, "w", newline="") as f:
w = csv.writer(f); w.writerow(["path", "label"]) # simple header
for p, y in items: w.writerow([p, int(y)]) # rows: absolute path, 0/1 label
train_items_raw = list(zip(train_paths, train_labels))
val_items_raw = list(zip(val_paths, val_labels))
test_items_raw = list(zip(test_paths, test_labels))
# write manifests for each split
save_manifest(train_items_raw, "manifests/train.csv")
save_manifest(val_items_raw, "manifests/val.csv")
save_manifest(test_items_raw, "manifests/test.csv")
print("Manifests written to manifests/{train,val,test}.csv")
Manifests written to manifests/{train,val,test}.csv
# === Stage-1 cache (dark-crop → final resize → PNG on disk) ===
# Inputs expected from previous cells:
# - ROOT (Path to dataset root)
# - train_items_raw, val_items_raw, test_items_raw: lists of (src_path, label)
# - crop_dark_frame(img, **kwargs) -> (cropped_img, (y0,y1,x0,x1))
# - _apply_final_resize(img, out=(W,H)) -> uint8 image of target size
# ---------- deterministic stage-1 preprocess (dark-crop → final resize; aug + normalize happen at load time) ----------
import cv2, hashlib  # hashlib is also needed by the filename helpers below
def _cache_root(resize_wh=(224,224), crop_kwargs=None, scheme="mirror",
resize_mode="cover_crop", pad_mode="reflect", pad_value="none"):
# build a signature that captures preprocessing choices; hash keeps the folder short
sig = f"ccrop_v2_{scheme}_size{resize_wh}_mode{resize_mode}_pad{pad_mode}_{pad_value}_kw{sorted((crop_kwargs or {}).items())}"
import hashlib
h = hashlib.md5(sig.encode()).hexdigest()[:8]
return Path("cache_stage1") / f"{sig}_{h}"
# --- 1) Use the safer padding inside your cache PREPROCESSOR (replace this helper) ---
def _preprocess_stage1_to_file(src_path, dst_path, resize_wh=(224,224),
crop_kwargs=None):
# deterministic, split-agnostic preproc: read → crop dark frame → final resize → write PNG
g = cv2.imread(src_path, cv2.IMREAD_GRAYSCALE)
if g is None: raise FileNotFoundError(src_path)
g1, _ = crop_dark_frame(g, **(crop_kwargs or {}))
g1 = _apply_final_resize(g1, out=resize_wh) # <— centralized resize (mode/pad configured elsewhere)
dst_path.parent.mkdir(parents=True, exist_ok=True)
if not cv2.imwrite(str(dst_path), g1):
raise IOError(f"Failed to write {dst_path}")
def _rel_mirror_under_root(src_p: Path, root: Path) -> Path:
"""
Mirror path relative to ROOT to keep readability AND uniqueness.
Example: chest_xray/train/NORMAL/img.jpeg -> train/NORMAL/img.png
Fallback to hashed filename if relative_to fails.
"""
try:
return src_p.relative_to(root).with_suffix(".png")
except Exception:
h = hashlib.md5(str(src_p).encode()).hexdigest()[:12]
return Path(src_p.stem + f"_{h}.png")
def _rel_hashed(src_p: Path) -> Path:
"""
Flat hashed filename (always unique) if you prefer not to mirror dirs.
"""
h = hashlib.md5(str(src_p).encode()).hexdigest()[:12]
return Path(src_p.stem + f"_{h}.png")
# --- 2) Re-materialize the cache with explicit pad_mode (tiny change to your builder) ---
def materialize_stage1_cache_split(items, split_tag: str, root_dir: Path,
resize_wh=(224,224), crop_kwargs=None,
scheme="mirror"):
# create (or reuse) a cache folder unique to the chosen preprocessing settings
cache_dir = _cache_root(resize_wh, crop_kwargs, scheme=scheme)
out, new_count = [], 0
for src, y in items:
src_p = Path(src)
if scheme == "mirror":
rel = _rel_mirror_under_root(src_p, root_dir) # readable, preserves split/class structure
else:
rel = _rel_hashed(src_p) # flat hashed filenames
rel = Path(split_tag) / (rel if scheme == "mirror" else rel.name)
dst = cache_dir / rel
if not dst.exists():
_preprocess_stage1_to_file(src, dst, resize_wh, crop_kwargs) # <<<< actual preprocessing
new_count += 1
out.append((str(dst), int(y)))
print(f"[cache:{split_tag}] root={cache_dir} | new={new_count} | total_indexed={len(out)}")
return out
# --- Configure deterministic preprocessing ---
CROP_KW = {"dark_frac": 0.80, "min_keep_run": 8}
RESIZE_WH = (224, 224)
SCHEME = "mirror"
_RESIZE_MODE = "cover_crop"
_PAD_MODE = "reflect"
_PAD_VALUE = "none"  # must match _cache_root's default so the debugger resolves the same cache folder/hash
_RESIZE_OUT = (224, 224)
# materialize cache for each split; outputs mirror input lists but paths now point into cache
train_items = materialize_stage1_cache_split(train_items_raw, "train", ROOT,
RESIZE_WH, CROP_KW, SCHEME)
val_items = materialize_stage1_cache_split(val_items_raw, "val", ROOT,
RESIZE_WH, CROP_KW, SCHEME)
test_items = materialize_stage1_cache_split(test_items_raw, "test", ROOT,
RESIZE_WH, CROP_KW, SCHEME)
# Optional quick sanity
import random, numpy as np
def _quick_cache_sanity(sample=6):
# spot-check presence, shape, dtype, and split-aware relative path
pool = train_items[:sample//2] + val_items[:sample//3] + test_items[:sample//3]
if not pool:
print("No cached items to check."); return
pool = random.sample(pool, min(sample, len(pool)))
for p, _ in pool:
a = cv2.imread(p, cv2.IMREAD_GRAYSCALE)
assert a is not None, f"Failed to read {p}"
assert a.shape == RESIZE_WH[::-1], f"Bad shape {a.shape} for {p}"
assert a.dtype == np.uint8, f"Bad dtype {a.dtype} for {p}"
assert f"/{Path(p).parts[-3]}/" in p.replace("\\","/"), "Expected split subfolder in cache path"
    print("Stage-1 cache sanity OK.")
_quick_cache_sanity()
[cache:train] root=cache_stage1/ccrop_v2_mirror_size(224, 224)_modecover_crop_padreflect_none_kw[('dark_frac', 0.8), ('min_keep_run', 8)]_1f7d515a | new=4719 | total_indexed=4719
[cache:val] root=cache_stage1/ccrop_v2_mirror_size(224, 224)_modecover_crop_padreflect_none_kw[('dark_frac', 0.8), ('min_keep_run', 8)]_1f7d515a | new=530 | total_indexed=530
[cache:test] root=cache_stage1/ccrop_v2_mirror_size(224, 224)_modecover_crop_padreflect_none_kw[('dark_frac', 0.8), ('min_keep_run', 8)]_1f7d515a | new=607 | total_indexed=607
Stage-1 cache sanity OK.
1.3.4 Visual debugger¶
Purpose¶
A small “debugger” was used to visualize how the preprocessing behaves on real samples. It shows each image at four stages: Original → Dark-cropped → Final resize preview → Cached PNG (plus a pad mask when using fit+pad). This helped settle on the resize choice and verify that stage-2 caching exactly matches the previewed result.
What the figure shows¶
For a handful of randomly sampled images (seeded for repeatability), each row presents:
- the untouched original,
- the result after dark-border cropping (Section 1.1),
- the final resize preview using the selected aspect-ratio–preserving method, and
- the cached file that the dataloader actually loads.
When in fit+pad mode, a simple binary mask highlights padded regions.
Observations¶
- The aspect-ratio–preserving resize keeps lung anatomy undistorted; any padding added is thin and visually benign.
- Almost the entire lung field is retained. A slight edge crop can occur on some images, which is expected given the conservative outside-in border crop (refer back to §1.1).
- The cached PNGs match the previewed resize (shape and appearance), confirming that the cache builder mirrors the pipeline correctly.
Why this matters¶
Earlier resize variants that stretched or filled shorter sides with synthetic content produced artifacts near the borders. With the current combination—conservative dark-border cropping followed by aspect-ratio–preserving resize—the preprocessing feels the most “honest”: geometry is preserved, padding is minimal, and the model sees consistent inputs.
Edge cases to watch¶
- Very unusual aspect ratios can lead to slightly thicker pads (still acceptable for batching).
- If lungs are extremely close to the frame, a few peripheral pixels may be lost during cropping; the max_crop_frac and small pad settings limit this.
- Any mismatch between the “final resize preview” and the “cached PNG” would indicate a pipeline bug; none were observed in the sampled examples.
# --- Quick debugger: Original → Dark-crop → Final resize preview → Cached ---
plt.rcParams["figure.dpi"] = 120 # sharper preview for debugging
# Must match the cache settings you used when building the cache
# (CROP_KW, _RESIZE_MODE, _PAD_MODE, _PAD_VALUE, RESIZE_WH, SCHEME defined earlier)
CACHE_DIR = _cache_root(
resize_wh=RESIZE_WH,
crop_kwargs=CROP_KW,
scheme=SCHEME,
resize_mode=_RESIZE_MODE,
pad_mode=_PAD_MODE,
pad_value=_PAD_VALUE,
)
def _dest_for_src(src, split_tag, scheme=SCHEME):
# compute where the cached PNG should live for a given source path
src_p = Path(src)
if scheme == "mirror":
rel = _rel_mirror_under_root(src_p, ROOT)
else:
rel = _rel_hashed(src_p)
# put under <cache>/<split>/...
rel = Path(split_tag) / (rel if scheme == "mirror" else rel.name)
return (CACHE_DIR / rel).as_posix()
def _final_resize_preview(img):
# mirror the pipeline exactly (same helper the cache uses)
return _apply_final_resize(img, out=RESIZE_WH)
def debug_stage1_on_src(src_path, split_tag="train"):
# end-to-end peek: original → crop → final resize → cached file
g0 = cv2.imread(src_path, cv2.IMREAD_GRAYSCALE)
assert g0 is not None, src_path
# 1) same dark-crop used for cache
g1, _ = crop_dark_frame(g0, **CROP_KW)
# 2) preview the same final resize (cover+crop OR fit+pad)
g2 = _final_resize_preview(g1)
# 3) read the cached PNG (what the dataloader actually uses)
dst = _dest_for_src(src_path, split_tag, scheme=SCHEME)
gc = cv2.imread(dst, cv2.IMREAD_GRAYSCALE)
# Optional: pad mask visualization if using fit_pad
mask = None
if _RESIZE_MODE == "fit_pad":
th, tw = RESIZE_WH
H, W = g1.shape[:2]
scale = min(th / H, tw / W)
newH, newW = int(round(H * scale)), int(round(W * scale))
top = (th - newH) // 2; left = (tw - newW) // 2
mask = np.zeros((th, tw), dtype=np.uint8)
mask[top:top+newH, left:left+newW] = 1
def show(ax, img, title):
# keep contrast consistent across uint8/float previews
vmin, vmax = (0, 255) if img.dtype == np.uint8 else (0, 1)
ax.imshow(img, cmap='gray', vmin=vmin, vmax=vmax)
ax.set_title(title); ax.axis('off')
ncols = 5 if mask is not None else 4
fig, ax = plt.subplots(1, ncols, figsize=(13.5 if ncols==5 else 12, 3.2), constrained_layout=True)
show(ax[0], g0, "Original")
show(ax[1], g1, "Dark-cropped")
show(ax[2], g2, "Final resize preview")
show(ax[3], gc if gc is not None else g2, "Cached PNG" if gc is not None else "Cached (missing)")
if mask is not None:
ax[4].imshow(mask, cmap='gray', vmin=0, vmax=1); ax[4].set_title("Pad mask"); ax[4].axis('off')
fig.suptitle(os.path.basename(src_path), fontsize=11)
plt.show()
# sample a few originals and inspect their cached versions (train split)
for (src, _y) in random.sample(train_items_raw, 3):
debug_stage1_on_src(src, split_tag="train")
1.3.5 Loss, optimizer & metrics¶
Loss¶
Since this is a binary classification task, the model outputs a single logit and is trained with BCEWithLogitsLoss.
To compensate for the class imbalance (≈3:1 pneumonia:normal in the train split, see §1.3.2), a class-weighted version is used with
pos_weight = NEG / POS.
This increases the penalty for misclassifying the minority class, making the model treat those errors more seriously.
A note on sampling. Instead of oversampling the minority class, which would repeatedly show the same few “normal” scans and risk creating an artificial training distribution, weighting the loss keeps the dataset natural while still addressing imbalance. This felt like a more faithful way of reflecting the real data.
Optimizer & LR schedule¶
Training uses AdamW with a base learning rate around 1e-4 and weight decay 1e-4.
If validation AUROC plateaus, the learning rate is automatically reduced using ReduceLROnPlateau, which halves the LR after two stagnant epochs. This helps the model settle into finer minima once coarse learning slows down.
Metrics¶
- Primary metric: AUROC on the validation set, chosen because it is threshold-free and stable under imbalance.
- Secondary metrics: Accuracy and F1 at a fixed 0.5 threshold, reported for interpretability. These are less robust but give a familiar sense of performance.
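To see why AUROC is threshold-free, it helps to recall its rank interpretation: the probability that a randomly chosen positive scores above a randomly chosen negative. A hand-rolled check of that interpretation (the notebook itself uses sklearn.metrics.roc_auc_score; the toy scores here are made up):

```python
import numpy as np

def auroc_by_ranks(y_true, y_score):
    """AUROC via pairwise ranking: fraction of (positive, negative) pairs
    where the positive outscores the negative (ties count half)."""
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    wins = (pos[:, None] > neg[None, :]).sum() \
         + 0.5 * (pos[:, None] == neg[None, :]).sum()
    return wins / (len(pos) * len(neg))

y = np.array([0, 0, 1, 1, 1])
s = np.array([0.1, 0.4, 0.35, 0.8, 0.9])
print(auroc_by_ranks(y, s))  # 5/6: 5 of 6 pos/neg pairs are ordered correctly
```

Only the ordering of the scores matters, which is why AUROC is unaffected by the 0.5 cut used for accuracy and F1.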
Connection to training loop¶
These definitions are pulled together in the training script (§1.3.6). Each run initializes a criterion, optimizer, and scheduler. Training then follows the freeze–unfreeze schedule described in §1.3.1, with early stopping triggered by validation AUROC.
The final model checkpoint (best val AUROC) is later reloaded for test evaluation (§1.3.7).
def make_criterion(pos_weight: float, device=DEVICE):
"""Class-weighted BCE with logits (single-logit output)."""
w = torch.tensor([float(pos_weight)], dtype=torch.float32, device=device) # 1D tensor so per-class weight fits shape
return nn.BCEWithLogitsLoss(pos_weight=w) # handles sigmoid + weighted BCE in one go
def make_optimizer(params, lr=1e-4, weight_decay=1e-4):
"""Pass an iterable of params; we rebuild after (un)freezing."""
return torch.optim.AdamW(list(params), lr=lr, weight_decay=weight_decay) # AdamW for decoupled WD
def make_scheduler(optimizer):
"""ReduceLROnPlateau keyed on val AUROC."""
return torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="max", factor=0.5, patience=2 # halve LR if AUROC plateaus
)
def lrs_of(optimizer):
return [pg["lr"] for pg in optimizer.param_groups] # support multi-param-group setups
def print_lrs(tag, scheduler, optimizer):
try: lrs = scheduler.get_last_lr() # works for most schedulers
except Exception: lrs = lrs_of(optimizer) # fallback if not supported
print(f"{tag} LRs:", " ".join(f"{lr:.2e}" for lr in lrs))
def sigmoid_np(x):
return 1.0 / (1.0 + np.exp(-x)) # stable enough for logging on small batches
def compute_epoch_metrics(y_true, y_logits, thresh=0.5):
"""
y_true: 1D {0,1}; y_logits: 1D raw logits
returns: dict {'auroc','acc','f1'}
"""
y_true = np.asarray(y_true, dtype=int)
y_prob = sigmoid_np(np.asarray(y_logits, dtype=float)) # convert logits → probs
y_pred = (y_prob >= thresh).astype(int) # fixed 0.5 cut for reporting
out = {}
try: out["auroc"] = float(roc_auc_score(y_true, y_prob)) # AUROC on probabilities
except Exception: out["auroc"] = float("nan") # edge cases: single-class batch, etc.
out["acc"] = float(accuracy_score(y_true, y_pred))
out["f1"] = float(f1_score(y_true, y_pred, zero_division=0))
return out
1.3.6 Training¶
Overview¶
After switching to the patient-level split (to avoid leakage), training follows a simple two-stage schedule with early stopping. The pipeline reads the stage-1 cached images (§1.3.3), applies train-only flip+rotate (§1.2.2), and optimizes an EfficientNet-B0 with a single-logit head (§1.3.1). Loss/optimizer/scheduler/metrics are defined in §1.3.5.
- Stage 1 (warm-up): freeze the backbone, train the head for a couple of epochs.
- Stage 2 (fine-tune): unfreeze everything and continue with a smaller effective LR (via ReduceLROnPlateau).
- Early stopping: stop if val AUROC doesn’t improve for 3 checks.
- AMP: enabled on GPU to speed up training without changing results.
The code below is organized in small, focused blocks so it’s easier to follow and swap components if needed.
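The early-stopping rule in the list above is simple enough to state as a helper. A sketch of the patience logic (an assumed helper for illustration, not the notebook's exact implementation; the real loop ties it to the best-val-AUROC checkpoint):

```python
class EarlyStopper:
    """Stop after `patience` consecutive checks without val-AUROC improvement."""
    def __init__(self, patience=3):
        self.patience = patience
        self.best = float("-inf")
        self.bad = 0  # consecutive non-improving checks

    def step(self, auroc):
        if auroc > self.best:
            self.best, self.bad = auroc, 0
        else:
            self.bad += 1
        return self.bad >= self.patience  # True → stop training

stopper = EarlyStopper(patience=3)
history = [0.90, 0.93, 0.94, 0.935, 0.94, 0.938]
stops = [stopper.step(a) for a in history]
print(stops)  # [False, False, False, False, False, True]
```

Note that an epoch matching (but not beating) the best AUROC still counts toward patience, which keeps the rule strict about genuine improvement.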
Datasets & dataloaders (read from cache; train-only aug)¶
This block defines a Dataset that loads the cached 224×224 PNGs and applies flip/rotate only for the training split. It then builds train/val/test DataLoaders with deterministic seeding.
# === Dataset + loader helpers (stage-1 cached pipeline) ===
# Expects:
# - cached item lists: train_items, val_items, test_items (from the cache cell in §1.3.3)
# - aug helpers: maybe_hflip, maybe_rotate
# - compute_epoch_metrics (from §1.3.5)
# - DEVICE, SEED defined earlier
USE_AMP = torch.cuda.is_available() # mixed precision only if CUDA is present
def _seed_worker(worker_id):
# per-worker deterministic seeds for dataloader randomness
import random
import numpy as np
import torch
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed); random.seed(worker_seed)
def _to_float01(u8):
# uint8 → float32 in [0,1]
return (u8.astype(np.float32) / 255.0)
class XRayStage1Dataset(Dataset):
"""
    Uses cached 224x224 PNGs (uint8) produced in §1.3.3.
Train-only aug (flip/rotate) applied here; val/test get none.
"""
def __init__(self, items, split="train", aug_strategy="rotate", normalize=True):
self.items = list(items)
self.split = split.lower()
self.aug_strategy = aug_strategy if self.split == "train" else "none"
self.normalize = normalize
def __len__(self): return len(self.items)
def __getitem__(self, idx):
p, y = self.items[idx]
g = cv2.imread(p, cv2.IMREAD_GRAYSCALE) # cached uint8 224x224
if g is None:
raise FileNotFoundError(p)
# train-only augmentation
if self.aug_strategy == "rotate":
g = maybe_rotate(g, max_deg=5, p=0.5)
elif self.aug_strategy == "flip+rotate":
g = maybe_hflip(g, p=0.5)
g = maybe_rotate(g, max_deg=5, p=0.5)
if self.normalize:
g = _to_float01(g) # float32 [0,1]
x = torch.as_tensor(g, dtype=torch.float32).unsqueeze(0) # [1,224,224]
y = torch.tensor([int(y)], dtype=torch.float32) # [1]
return x, y
def make_cached_loaders(train_items, val_items, test_items,
aug_strategy="flip+rotate",
batch_size=32, generator=None):
# worker/pinning config scales with hardware
USE_GPU = torch.cuda.is_available()
num_workers = min(8 if USE_GPU else 2, (os.cpu_count() or 2))
common = dict(batch_size=batch_size, pin_memory=USE_GPU,
num_workers=num_workers, worker_init_fn=_seed_worker)
if num_workers > 0:
common.update(dict(persistent_workers=True, prefetch_factor=4))
if generator is not None:
common["generator"] = generator
train_dl = DataLoader(
XRayStage1Dataset(train_items, "train", aug_strategy=aug_strategy, normalize=True),
shuffle=True, **common
)
val_dl = DataLoader(
XRayStage1Dataset(val_items, "val", aug_strategy="none", normalize=True),
shuffle=False, **common
)
test_dl = DataLoader(
XRayStage1Dataset(test_items, "test", aug_strategy="none", normalize=True),
shuffle=False, **common
)
return train_dl, val_dl, test_dl
Train / evaluate loops¶
Plain PyTorch loops with optional mixed precision. Validation returns AUROC/Acc/F1 from 1.3.5.
def train_one_epoch(model, loader, criterion, optimizer, scaler, device=DEVICE, log_interval=30):
model.train()
running_loss, n = 0.0, 0
seen, start = 0, time.time()
for i, (xb, yb) in enumerate(loader, start=1):
xb = xb.to(device, non_blocking=True)
yb = yb.to(device, non_blocking=True)
optimizer.zero_grad(set_to_none=True)
# AMP forward pass (ops run in mixed precision on CUDA if available)
with torch.amp.autocast(device_type='cuda', enabled=USE_AMP):
logits = model(xb)
loss = criterion(logits, yb)
# AMP backward: scale → step → update
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
bs = xb.size(0)
running_loss += loss.item() * bs; n += bs; seen += bs
if i % log_interval == 0 or i == len(loader):
dt = time.time() - start
print(f" batch {i:>4}/{len(loader)} avg_loss={running_loss/max(1,n):.4f} "
f"{i/dt:.2f} it/s {seen/dt:.1f} samp/s")
return running_loss / max(1, n)
@torch.no_grad()
def evaluate(model, loader, criterion, device=DEVICE, log_interval=0):
model.eval()
running_loss, n = 0.0, 0
all_logits, all_true = [], []
seen, start = 0, time.time()
for i, (xb, yb) in enumerate(loader, start=1):
xb = xb.to(device, non_blocking=True)
yb = yb.to(device, non_blocking=True)
logits = model(xb)
loss = criterion(logits, yb)
bs = xb.size(0)
running_loss += loss.item() * bs; n += bs
# keep raw logits/labels for AUROC/ACC/F1 later
all_logits.append(logits.squeeze(1).detach().cpu())
all_true.append(yb.squeeze(1).detach().cpu()); seen += bs
if log_interval and (i % log_interval == 0 or i == len(loader)):
dt = time.time() - start
print(f" [val] batch {i:>4}/{len(loader)} avg_loss={running_loss/max(1,n):.4f} "
f"{i/dt:.2f} it/s {seen/dt:.1f} samp/s")
val_loss = running_loss / max(1, n)
y_true = torch.cat(all_true).numpy().astype(np.float32)
y_logits = torch.cat(all_logits).numpy().astype(np.float32)
metrics = compute_epoch_metrics(y_true, y_logits, thresh=0.5) # consistent 0.5 cut
return val_loss, metrics
Freeze / unfreeze helpers¶
Small utilities to (1) train the head first, then (2) fine-tune the whole network.
def freeze_backbone_only(model, head_names=("classifier","fc","head")):
"""
Freeze everything, then unfreeze any module whose name contains one of head_names.
Adjust head_names to match your build_model() head naming.
"""
# blanket-freeze first so we start from a known state
for p in model.parameters():
p.requires_grad = False
# unfreeze any submodule whose qualified name hints it's the head
found = []
for name, module in model.named_modules():
if any(h in name.lower() for h in head_names):
for p in module.parameters():
p.requires_grad = True
found.append(name)
# fallback for models where the head is attached directly as an attribute
if not found:
for attr in head_names:
if hasattr(model, attr):
for p in getattr(model, attr).parameters():
p.requires_grad = True
found.append(attr)
# quick audit: how many params are trainable now, and which modules were targeted
tr = sum(p.numel() for p in model.parameters() if p.requires_grad)
tot = sum(p.numel() for p in model.parameters())
print(f"[freeze] trainable={tr}/{tot} params | head modules: {found or 'NONE (check head_names)'}")
def unfreeze_all(model):
"""Enable grad for all params."""
for p in model.parameters():
p.requires_grad = True
# confirm everything is now trainable
tr = sum(p.numel() for p in model.parameters() if p.requires_grad)
tot = sum(p.numel() for p in model.parameters())
print(f"[unfreeze] trainable={tr}/{tot} params")
Single training run¶
This cell wires everything together: loaders, model, criterion/optimizer/scheduler (§1.3.5), AMP scaler, and checkpointing by best val AUROC. After training, it reloads the best weights and reports metrics on val/test.
set_global_seed(SEED) # lock global RNGs for reproducibility
gen = torch.Generator().manual_seed(SEED) # DataLoader-level generator
train_dl, val_dl, test_dl = make_cached_loaders(
train_items, val_items, test_items,
aug_strategy="flip+rotate", # or "rotate" / "none"
batch_size=32, generator=gen
)
# 1) Build model + criterion
model = build_model(backbone="efficientnet_b0", pretrained=True, in_chans=1).to(DEVICE)
criterion = make_criterion(POS_WEIGHT, device=DEVICE) # class-weighted BCE-with-logits
# 2) STAGE 1 — freeze backbone, train head; THEN build optimizer/scheduler
freeze_backbone_only(model, head_names=("classifier","fc","head")) # match your build_model naming
assert any(p.requires_grad for p in model.parameters()), "No trainable params after freeze!"
optimizer = make_optimizer((p for p in model.parameters() if p.requires_grad),
lr=1e-4, weight_decay=1e-4) # optimize head only
scheduler = make_scheduler(optimizer) # ReduceLROnPlateau on val AUROC
# AMP-compatible GradScaler across torch 2.2.x .. 2.4.x+
if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
# Newer API (2.3/2.4+): supports device="cuda"
scaler = torch.amp.GradScaler(device="cuda", enabled=torch.cuda.is_available())
else:
# Older API (2.2.x and earlier): lives under torch.cuda.amp, no device arg
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
ckpt_dir = Path(os.environ.get("CXR_SAVE_DIR", f"{PROJ_ROOT}/checkpoints")).expanduser()
ckpt_dir.mkdir(parents=True, exist_ok=True)
ckpt_path = ckpt_dir / 'best_model.pt' # best-by-val-AUROC checkpoint
best_auroc = -float("inf")
# --- Stage 1: train head only ---
for epoch in range(1, 2+1):
tr = train_one_epoch(model, train_dl, criterion, optimizer, scaler, log_interval=30)
vl, m = evaluate(model, val_dl, criterion)
scheduler.step(m["auroc"]) # plateau scheduler expects a metric
print(f"[Frozen {epoch}/2] train_loss={tr:.4f} val_loss={vl:.4f} AUROC={m['auroc']:.3f} Acc={m['acc']:.3f} F1={m['f1']:.3f}")
if m["auroc"] > best_auroc:
best_auroc = m["auroc"]
torch.save(model.state_dict(), ckpt_path) # keep the current best
# 3) STAGE 2 — unfreeze all; REBUILD optimizer/scheduler
unfreeze_all(model) # fine-tune entire network
optimizer = make_optimizer(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = make_scheduler(optimizer)
epochs_no_improve = 0
for epoch in range(1, 3+1):
tr = train_one_epoch(model, train_dl, criterion, optimizer, scaler)
vl, m = evaluate(model, val_dl, criterion)
scheduler.step(m["auroc"])
print(f"[Unfrozen {epoch}/3] train_loss={tr:.4f} val_loss={vl:.4f} AUROC={m['auroc']:.3f} Acc={m['acc']:.3f} F1={m['f1']:.3f}")
if m["auroc"] > best_auroc:
best_auroc = m["auroc"]; epochs_no_improve = 0
torch.save(model.state_dict(), ckpt_path) # update best checkpoint
else:
epochs_no_improve += 1
if epochs_no_improve >= 3:
print("Early stopping (no AUROC improvement).") # simple patience on AUROC
break
# 4) Evaluate best on test
best_model = build_model(backbone="efficientnet_b0", pretrained=False, in_chans=1).to(DEVICE)
best_model.load_state_dict(torch.load(str(ckpt_path), map_location=DEVICE)) # strict arch match
_, val_metrics = evaluate(best_model, val_dl, criterion)
_, test_metrics = evaluate(best_model, test_dl, criterion)
print("BEST (val):", val_metrics)
print("TEST:", test_metrics)
print("Checkpoint saved to:", ckpt_path)
[freeze] trainable=637821/4008253 params | head modules: ['features.1.0.block.1.fc1', 'features.1.0.block.1.fc2', 'features.2.0.block.2.fc1', 'features.2.0.block.2.fc2', 'features.2.1.block.2.fc1', 'features.2.1.block.2.fc2', 'features.3.0.block.2.fc1', 'features.3.0.block.2.fc2', 'features.3.1.block.2.fc1', 'features.3.1.block.2.fc2', 'features.4.0.block.2.fc1', 'features.4.0.block.2.fc2', 'features.4.1.block.2.fc1', 'features.4.1.block.2.fc2', 'features.4.2.block.2.fc1', 'features.4.2.block.2.fc2', 'features.5.0.block.2.fc1', 'features.5.0.block.2.fc2', 'features.5.1.block.2.fc1', 'features.5.1.block.2.fc2', 'features.5.2.block.2.fc1', 'features.5.2.block.2.fc2', 'features.6.0.block.2.fc1', 'features.6.0.block.2.fc2', 'features.6.1.block.2.fc1', 'features.6.1.block.2.fc2', 'features.6.2.block.2.fc1', 'features.6.2.block.2.fc2', 'features.6.3.block.2.fc1', 'features.6.3.block.2.fc2', 'features.7.0.block.2.fc1', 'features.7.0.block.2.fc2', 'classifier', 'classifier.0', 'classifier.1']
batch 30/148 avg_loss=0.3338 2.66 it/s 85.0 samp/s
batch 60/148 avg_loss=0.3246 4.53 it/s 145.1 samp/s
batch 90/148 avg_loss=0.3033 5.93 it/s 189.8 samp/s
batch 120/148 avg_loss=0.2851 6.98 it/s 223.2 samp/s
batch 148/148 avg_loss=0.2695 5.72 it/s 182.5 samp/s
[Frozen 1/2] train_loss=0.2695 val_loss=0.1949 AUROC=0.965 Acc=0.881 F1=0.914
batch 30/148 avg_loss=0.1788 13.62 it/s 435.7 samp/s
batch 60/148 avg_loss=0.1804 13.84 it/s 442.9 samp/s
batch 90/148 avg_loss=0.1743 12.27 it/s 392.7 samp/s
batch 120/148 avg_loss=0.1710 12.35 it/s 395.1 samp/s
batch 148/148 avg_loss=0.1635 12.28 it/s 391.6 samp/s
[Frozen 2/2] train_loss=0.1635 val_loss=0.1393 AUROC=0.977 Acc=0.940 F1=0.959
[unfreeze] trainable=4008253/4008253 params
batch 30/148 avg_loss=0.1276 4.28 it/s 137.1 samp/s
batch 60/148 avg_loss=0.1149 5.98 it/s 191.4 samp/s
batch 90/148 avg_loss=0.1048 7.19 it/s 230.0 samp/s
batch 120/148 avg_loss=0.0943 8.00 it/s 255.9 samp/s
batch 148/148 avg_loss=0.0922 6.55 it/s 208.7 samp/s
[Unfrozen 1/3] train_loss=0.0922 val_loss=0.0725 AUROC=0.990 Acc=0.957 F1=0.970
batch 30/148 avg_loss=0.0538 11.00 it/s 352.2 samp/s
batch 60/148 avg_loss=0.0505 11.33 it/s 362.6 samp/s
batch 90/148 avg_loss=0.0572 11.44 it/s 366.2 samp/s
batch 120/148 avg_loss=0.0568 10.78 it/s 345.0 samp/s
batch 148/148 avg_loss=0.0561 10.62 it/s 338.7 samp/s
[Unfrozen 2/3] train_loss=0.0561 val_loss=0.0583 AUROC=0.994 Acc=0.964 F1=0.976
batch 30/148 avg_loss=0.0578 10.75 it/s 344.1 samp/s
batch 60/148 avg_loss=0.0495 11.16 it/s 357.0 samp/s
batch 90/148 avg_loss=0.0437 11.23 it/s 359.3 samp/s
batch 120/148 avg_loss=0.0428 10.29 it/s 329.4 samp/s
batch 148/148 avg_loss=0.0438 10.54 it/s 336.0 samp/s
[Unfrozen 3/3] train_loss=0.0438 val_loss=0.0658 AUROC=0.992 Acc=0.964 F1=0.975
BEST (val): {'auroc': 0.9942022643986217, 'acc': 0.9641509433962264, 'f1': 0.9755469755469756}
TEST: {'auroc': 0.9957658630202711, 'acc': 0.9621087314662273, 'f1': 0.9709228824273072}
Checkpoint saved to: /content/drive/MyDrive/code/cxr_assignment/checkpoints/best_model.pt
1.3.7 Results, checkpoint, and sanity check¶
Validation and test performance.
On the patient-level split, the final run achieved:
- Validation: AUROC 0.9942, Acc 0.9642, F1 0.9755
- Test: AUROC 0.9958, Acc 0.9621, F1 0.9709
These numbers are consistent with earlier runs and remain stable across the preprocessing variants tried.
Checkpoint selection (brief).
During fine-tuning, validation improved, then began to soften slightly near the end—typical early signs of overfitting. I therefore report results from the best validation AUROC checkpoint (saved before that softening), rather than the very last epoch.
Why such high results need caution.
Even with patient-level splits and the updated preprocessing, performance remains extremely high. This suggests the model may still benefit from dataset-specific cues (e.g., scanner/processing characteristics, borders, acquisition patterns) in addition to disease-relevant lung features. Because train/val/test all come from the same source distribution, near-perfect within-dataset generalization can still be fragile on a truly independent dataset.
Sanity check: shuffled-label AUC.
To verify the evaluation pipeline, I also compute a shuffled-label AUROC by randomly permuting test labels. As expected, this yields ~0.50, which supports that metric calculation, label alignment, and split handling are functioning as intended (though it does not address shortcut learning).
# Run after training so these names exist in the session:
# build_model, make_loaders, compute_epoch_metrics, make_criterion, POS_WEIGHT, DEVICE
# set_global_seed / SEED (optional), XRDataset, DataLoader, etc.
CHOSEN = "best_model"
CKPT_DIR = f"{PROJ_ROOT}/checkpoints"
CKPT_PATH = os.path.join(CKPT_DIR, f"{CHOSEN}.pt")
# Common run config
COMMON_RUN_KW = dict(
resize_to_wh=(224, 224),
normalize=True,
detect_kwargs={"border_ratio": 0.10, "pct": 99.5, "dilate_edge": 7},
crop_kwargs={"border_ratio": 0.14, "se_frac": 0.025},
)
# Variant-specific knobs used for this run
RUN_KW = dict(COMMON_RUN_KW)
def make_loaders(run_kw=None, batch_size=32, generator=None):
"""
Thin wrapper around make_cached_loaders; pulls aug_strategy from run_kw.
"""
if run_kw is None:
run_kw = {}
aug = run_kw.get("aug_strategy", "none")
return make_cached_loaders(
train_items, val_items, test_items,
aug_strategy=aug, batch_size=batch_size, generator=generator
)
def _sigmoid(z):
z = np.clip(z, -50, 50) # avoid exp overflow
return 1.0 / (1.0 + np.exp(-z))
def _load_model(path):
# same arch as training, no pretrained weights
model = build_model(backbone="efficientnet_b0", pretrained=False, in_chans=1).to(DEVICE)
state = torch.load(path, map_location=DEVICE)
model.load_state_dict(state)
return model
@torch.no_grad()
def _collect_logits(model, loader):
# forward pass over loader; keep raw logits + labels
model.eval()
logits_all, y_all = [], []
for xb, yb in loader:
xb = xb.to(DEVICE, non_blocking=True)
logits = model(xb).squeeze(1).detach().cpu().numpy()
y = yb.squeeze(1).detach().cpu().numpy()
logits_all.append(logits)
y_all.append(y)
y_true = np.concatenate(y_all).astype(np.float32)
y_logits = np.concatenate(logits_all).astype(np.float32)
return y_true, y_logits
def compute_shuf_auc(y_true, y_logits, seed=123):
# baseline sanity: shuffle labels, keep predictions fixed
rng = np.random.default_rng(seed)
y_perm = rng.permutation(y_true)
y_prob = _sigmoid(y_logits)
return float(roc_auc_score(y_perm, y_prob))
# ---- Run: reload checkpoint, evaluate on TEST, then SHUF AUC ----
assert os.path.isfile(CKPT_PATH), f"Checkpoint not found: {CKPT_PATH}"
# loaders with selected config
_, _, test_dl = make_loaders(RUN_KW, batch_size=32, generator=None)
# load weights and score
model = _load_model(CKPT_PATH)
criterion = make_criterion(POS_WEIGHT, device=DEVICE)
y_true, y_logits = _collect_logits(model, test_dl)
metrics = compute_epoch_metrics(y_true, y_logits, thresh=0.5)
# shuffled-label sanity AUC
SEED_LOCAL = SEED if 'SEED' in globals() else 123
shuf_auc = compute_shuf_auc(y_true, y_logits, seed=SEED_LOCAL)
print(f"Checkpoint: {CKPT_PATH}")
print(f"TEST: AUC={metrics['auroc']:.3f} Acc={metrics['acc']:.3f} F1={metrics['f1']:.3f}")
print(f"SHUF AUC (labels permuted, predictions fixed): {shuf_auc:.3f} (expect ~0.50)")
Checkpoint: /content/drive/MyDrive/code/cxr_assignment/checkpoints/best_model.pt
TEST: AUC=0.996 Acc=0.962 F1=0.971
SHUF AUC (labels permuted, predictions fixed): 0.528 (expect ~0.50)
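As a self-contained illustration of why permuted labels should land near 0.5, the same check can be reproduced on synthetic scores with a rank-based (Mann-Whitney) AUROC; rank_auc below is a local helper for this demo, not part of the pipeline:

```python
import numpy as np

def rank_auc(y, s):
    # Mann-Whitney U formulation of AUROC (fine for continuous,
    # tie-free scores); avoids depending on sklearn for the demo.
    ranks = s.argsort().argsort() + 1                  # 1-based ranks by score
    n_pos = int(y.sum()); n_neg = len(y) - n_pos
    return (ranks[y == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

rng = np.random.default_rng(0)
y = (rng.random(2000) < 0.5).astype(int)
scores = y * 1.0 + rng.normal(0, 1, size=2000)         # informative scores
y_perm = rng.permutation(y)                            # destroy the pairing

print(f"AUC (true labels):     {rank_auc(y, scores):.3f}")   # well above 0.5
print(f"AUC (shuffled labels): {rank_auc(y_perm, scores):.3f}")  # near 0.5
```

Shuffling removes the label-score pairing but keeps the class prior, so any systematic deviation from ~0.5 would point at a bookkeeping bug rather than model skill.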
2 Occlusion function¶
2.1 Setup for occlusion¶
To start with the occlusion experiments, I reload the best checkpoint from Part 1 and build a lightweight test loader from the stage-2 cached images. The idea is to take correctly classified test cases and see how the model’s confidence changes when parts of the image are hidden.
For consistency, I select eight test images where the model predicted correctly: four with very high confidence (≥0.95) and four with more moderate confidence (0.80–0.90). This way the results include both “easy” and “less certain” examples, which makes the occlusion test more informative.
# Expects from Part 1:
# - DEVICE, build_model
# - test_items: list[(cached_png_path, label)] from stage-2 cache
# - XRayStage1Dataset (normalize=True), DataLoader
def make_test_loader(items, batch_size=128):
# cached PNGs, no aug, normalized to [0,1]
ds = XRayStage1Dataset(items, split="test", aug_strategy="none", normalize=True)
use_gpu = torch.cuda.is_available()
return DataLoader(ds, batch_size=batch_size, shuffle=False,
pin_memory=use_gpu, num_workers=min(8 if use_gpu else 2, (os.cpu_count() or 2)))
# 1) Load best checkpoint
CKPT_DIR = Path(os.environ.get("CXR_SAVE_DIR", "/content/drive/MyDrive/code/cxr_assignment/checkpoints")).expanduser()
CKPT_PATH = CKPT_DIR / "best_model.pt"
assert CKPT_PATH.exists(), f"Missing checkpoint: {CKPT_PATH}"
model_occ = build_model(backbone="efficientnet_b0", pretrained=False, in_chans=1).to(DEVICE)
state = torch.load(str(CKPT_PATH), map_location=DEVICE)
model_occ.load_state_dict(state)
model_occ.eval()
# 2) Collect test probs/preds
test_dl_eval = make_test_loader(test_items, batch_size=128)
def _sigmoid_t(t): return torch.sigmoid(t)  # built-in, numerically stable
all_probs, all_true, all_pred = [], [], []
with torch.no_grad():
for xb, yb in test_dl_eval:
xb = xb.to(DEVICE); yb = yb.to(DEVICE)
logits = model_occ(xb).squeeze(1)
probs = _sigmoid_t(logits) # probabilities per sample
preds = (probs >= 0.5).long() # hard 0.5 cut
all_probs.append(probs.detach().cpu())
all_true.append(yb.long().squeeze(1).detach().cpu())
all_pred.append(preds.detach().cpu())
probs = torch.cat(all_probs).numpy().astype(np.float32)
ytrue = torch.cat(all_true).numpy().astype(np.int64)
ypred = torch.cat(all_pred).numpy().astype(np.int64)
# 3) Pick 8 correctly classified test samples:
# - 4 with very high confidence (>=0.95)
# - 4 with moderate-high confidence (>=0.80 and <0.90)
hi = [i for i in range(len(test_items)) if (ypred[i]==ytrue[i]) and (probs[i] >= 0.95)]
mid = [i for i in range(len(test_items)) if (ypred[i]==ytrue[i]) and (0.80 <= probs[i] < 0.90)]
random.Random(42).shuffle(hi); random.Random(42).shuffle(mid) # deterministic pick order
pick_idx = (hi[:4] + mid[:4])[:8]
assert pick_idx, "No suitable test samples found."
picked = [(test_items[i][0], int(test_items[i][1]), float(probs[i])) for i in pick_idx]
print("Selected images (path tail, label, prob):")
for pth, y, pr in picked:
print(os.path.basename(pth), y, f"{pr:.3f}")
Selected images (path tail, label, prob):
person294_bacteria_1386.png 1 0.961
person1352_bacteria_3444.png 1 0.999
person294_virus_611.png 1 0.999
person1063_virus_1765.png 1 0.986
person554_virus_1094.png 1 0.857
person520_virus_1039.png 1 0.869
person1273_virus_2191.png 1 0.841
person1588_virus_2762.png 1 0.835
2.1.1 Dataset mean¶
When occluding pixels, we need to replace them with some “neutral” value. A simple choice is black (0.0), but that can look unnatural compared to the rest of the scan. Instead, I compute the dataset mean intensity from the training cache. This gives a realistic filler that matches the overall brightness of the images. It should reduce the risk of the model reacting too strongly to the filler itself.
def dataset_mean_from_cached(items, max_samples=None):
"""
Compute mean pixel value over cached grayscale PNGs.
Returns mean in [0,1].
"""
s = 0.0
n = 0
for i, (pth, _) in enumerate(items):
if max_samples is not None and i >= max_samples:
break # optional cap for quick pass
a = cv2.imread(pth, cv2.IMREAD_GRAYSCALE) # uint8 [0..255]
if a is None:
continue # skip unreadables
s += float(a.sum()) # sum of all pixels in this image
n += int(a.size) # number of pixels
assert n > 0, "No cached images found to compute dataset mean."
return (s / (n * 255.0)) # normalize to [0,1]
# Prefer the training cache for the mean (typical convention)
MEAN_DATASET = dataset_mean_from_cached(train_items, max_samples=None) # cap for speed if needed
print(f"[Occlusion] Dataset mean over cached files: {MEAN_DATASET:.6f}")
[Occlusion] Dataset mean over cached files: 0.575157
2.2 occlusion_drop function¶
The occlusion function measures how much the model’s probability for the true class drops when a part of the image is hidden.
Steps:¶
- Run the model on the original image to get the probability for the correct label.
- Apply a binary mask to hide part of the image, filling with a neutral value (by default the per-image mean).
- Run the model again on the occluded image.
- Return the difference in probability (“drop”).
The function works for both NumPy arrays and PyTorch tensors, handles grayscale input in [0,1], and uses torch.no_grad() for efficiency.
def _to_float01_hw(x):
# ensure [H,W] float32 in [0,1]; accept [1,H,W] or uint8
x = np.asarray(x)
if x.ndim == 3 and x.shape[0] == 1: # [1,H,W] -> [H,W]
x = x[0]
x = x.astype(np.float32)
if x.max() > 1.0 or x.min() < 0.0: # allow uint8 input
x = x / 255.0
return x # [H,W] in [0,1]
@torch.no_grad()
def occlusion_drop(img, mask, model, true_label:int, neutral:float|str|None=None, device=DEVICE):
"""
img: [H,W] or [1,H,W] in [0,1] (np or torch)
mask: [H,W], 1=keep, 0=occlude
model: eval-mode, single-logit output (binary)
neutral: None | "dataset" | numeric (e.g. 0.0). None => per-image mean.
returns: (drop, p_true_orig, p_true_occl, x_occ[H,W] in [0,1])
"""
# normalize and drop channel if present
x = img.detach().cpu().numpy() if isinstance(img, torch.Tensor) else img
x = _to_float01_hw(x)
H, W = x.shape
# mask → float in {0,1}
m = (np.asarray(mask, dtype=np.uint8) > 0).astype(np.float32)
# baseline filler (constant): numeric / dataset mean / per-image mean
if isinstance(neutral, (int, float)):
filler = np.full_like(x, float(neutral), dtype=np.float32)
elif neutral == "dataset":
filler = np.full_like(x, float(MEAN_DATASET), dtype=np.float32)
else: # None → per-image mean
filler = np.full_like(x, float(x.mean()), dtype=np.float32)
# compose occluded image
x_occ = (x * m) + (1.0 - m) * filler # [H,W], [0,1]
# model expects BCHW
xt = torch.from_numpy(x [None, None, ...]).to(device) # [1,1,H,W]
xot = torch.from_numpy(x_occ[None, None, ...]).to(device)
# probs pre/post occlusion
p0 = torch.sigmoid(model(xt ).squeeze(1)).item()
p1 = torch.sigmoid(model(xot).squeeze(1)).item()
# prob of the true label; drop = how much it fell
p_true_orig = p0 if true_label == 1 else (1.0 - p0)
p_true_occl = p1 if true_label == 1 else (1.0 - p1)
drop = float(p_true_orig - p_true_occl)
return drop, float(p_true_orig), float(p_true_occl), x_occ
2.3 Mask shapes¶
To probe different regions, I define simple binary masks:
- Quadrants (Q1–Q4): each hides one quarter of the image, giving a coarse sense of which side matters more.
- Center square: hides the middle of the lungs; included to check whether the center alone is crucial.
- Ring (“donut”): the opposite, hiding the borders and keeping only the center, to test for edge-based cues.
These are simple to generate and cover both central and peripheral structures. They don’t explain everything about what the network looks at, but they provide a first impression of sensitivity.
def make_center_mask(H, W, size=None, frac=None):
"""
Create a [H,W] mask with a square occluder in the center.
- size: side length in pixels (int)
- frac: fraction of min(H,W) for side length (e.g., 0.2 -> 20%)
Exactly one of {size, frac} must be given.
"""
assert (size is None) ^ (frac is None), "Pass either size or frac." # XOR: pick one
side = int(round(min(H, W) * float(frac))) if frac is not None else int(size)
side = max(1, min(side, min(H, W))) # clamp to [1, min(H,W)]
y0 = (H - side)//2; x0 = (W - side)//2 # centered square
m = np.ones((H, W), np.uint8)
m[y0:y0+side, x0:x0+side] = 0 # 0 = occlude, 1 = keep
return m, side # return the actual side we used (useful for titles/tables)
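The quadrant and ring masks described above can be generated the same way; a minimal sketch (the names make_quadrant_mask and make_ring_mask are mine, and follow the same 1=keep / 0=occlude convention occlusion_drop expects):

```python
import numpy as np

def make_quadrant_mask(H, W, quadrant):
    """Occlude one quadrant (1=TL, 2=TR, 3=BL, 4=BR). 1=keep, 0=occlude."""
    m = np.ones((H, W), np.uint8)
    ys = slice(0, H // 2) if quadrant in (1, 2) else slice(H // 2, H)
    xs = slice(0, W // 2) if quadrant in (1, 3) else slice(W // 2, W)
    m[ys, xs] = 0
    return m

def make_ring_mask(H, W, keep_frac=0.5):
    """Occlude the border "ring": keep only a central disk whose diameter
    is keep_frac * min(H, W). 1=keep, 0=occlude."""
    yy, xx = np.ogrid[:H, :W]
    cy, cx = (H - 1) / 2.0, (W - 1) / 2.0
    r = (min(H, W) * float(keep_frac)) / 2.0
    keep = ((yy - cy) ** 2 + (xx - cx) ** 2) <= r ** 2
    return keep.astype(np.uint8)
```

Both drop straight into occlusion_drop, which only needs a {0,1} array of the right shape.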
2.4 Running occlusion tests¶
With the function and masks ready, I run the occlusion tests on the eight selected images. For each case, the display shows:
- the original image,
- the mask overlay,
- the occluded image, and
- the drop value (original probability → occluded probability).
Using both a small mask (10×10) and a large mask (≈96×96) helps compare local versus more global occlusions.
The expectation is straightforward: if hiding a region makes the model much less confident, that region likely contributed to the decision. The drop values therefore give a rough measure of importance, although they shouldn’t be over-interpreted as precise explanations.
plt.rcParams["figure.dpi"] = 110 # sharper figures
def draw_center_square_overlay(img01, side, color=(1,0,0), lw=2):
"""Return RGB overlay with a centered square outline."""
H, W = img01.shape
y0 = (H - side)//2; x0 = (W - side)//2
rgb = np.dstack([img01]*3).copy() # grayscale → RGB for plotting
# rectangle is drawn later with matplotlib (keeps this helper lightweight)
return rgb, (x0, y0, side, side)
def demo_two_sizes_row(path, label, neutral="dataset", sizes=(10, 96)):
g = cv2.imread(path, cv2.IMREAD_GRAYSCALE).astype(np.float32)/255.0 # [0,1]
H, W = g.shape
# compute occlusions for both sizes
results = []
for sz in sizes:
# centered square mask (0=occlude)
m = np.ones((H,W), np.uint8)
y0 = (H - sz)//2; x0 = (W - sz)//2
m[y0:y0+sz, x0:x0+sz] = 0
drop, p0, p1, x_occ = occlusion_drop(g, m, model_occ, true_label=label,
neutral=neutral, device=DEVICE)
results.append((sz, drop, p0, p1, x_occ, (x0, y0, sz, sz)))
# layout: Original | (10) Mask | (10) Occluded | (96) Mask | (96) Occluded
fig, axs = plt.subplots(1, 5, figsize=(12.5, 2.6), constrained_layout=True)
fig.suptitle(f"{os.path.basename(path)} | y={label}", fontsize=10)
# col 1: original
axs[0].imshow(g, cmap='gray', vmin=0, vmax=1); axs[0].set_title("Original"); axs[0].axis('off')
# small helper to draw a red square overlay
def show_overlay(ax, img01, box, title):
ax.imshow(img01, cmap='gray', vmin=0, vmax=1)
import matplotlib.patches as patches
x0, y0, side = box[0], box[1], box[2]
rect = patches.Rectangle((x0, y0), side, side, linewidth=2, edgecolor='r', facecolor='none')
ax.add_patch(rect)
ax.set_title(title); ax.axis('off')
# col 2-3: small mask
sz, drop, p0, p1, x_occ, box = results[0]
show_overlay(axs[1], g, box, f"Mask {sz}px")
axs[2].imshow(x_occ, cmap='gray', vmin=0, vmax=1); axs[2].axis('off')
axs[2].set_title(f"Occluded {sz}px\nΔ={drop:.3f} (p:{p0:.3f}→{p1:.3f})", fontsize=8)
# col 4-5: large mask
sz, drop, p0, p1, x_occ, box = results[1]
show_overlay(axs[3], g, box, f"Mask {sz}px")
axs[4].imshow(x_occ, cmap='gray', vmin=0, vmax=1); axs[4].axis('off')
axs[4].set_title(f"Occluded {sz}px\nΔ={drop:.3f} (p:{p0:.3f}→{p1:.3f})", fontsize=8)
plt.show()
# run for selected images (tiny + big square)
for pth, y, _ in picked[:8]:
demo_two_sizes_row(pth, y, neutral="dataset", sizes=(10, 96))
3 Superpixels¶
3.1 Setup and baseline choices¶
Goal here is to break each image into small, locally coherent regions (“superpixels”) so later we can probe or occlude them as units. I start with SLIC (Simple Linear Iterative Clustering) using the assignment’s baseline settings: about 100 segments, compactness = 10, and sigma = 1. Since the images are grayscale, the input is duplicated into three channels for SLIC, and LAB conversion is left off by default (I keep the option to enable it if needed).
These parameters are a reasonable first pass: ~100 segments give enough detail without being too fragmented; compactness balances color/brightness similarity with spatial proximity; and a small sigma smooths noise slightly before clustering. I keep everything simple and consistent across images so the later comparisons are easier to read.
# Baseline params suggested by the assignment
SLIC_N_SEGMENTS = 100
SLIC_COMPACTNESS = 10.0
SLIC_SIGMA = 1.0
USE_LAB = False # set True if you want to run SLIC in LAB space (optional)
plt.rcParams["figure.dpi"] = 110
3.2 Generating superpixels (SLIC)¶
Given a grayscale image in [0,1], I create a 3-channel version (stacked copies) and run SLIC to obtain a label map seg with values 0..K-1. Each label corresponds to one superpixel. I also return an RGB version for plotting overlays.
Notes from trying a few settings:
- Segment count: going much lower (e.g., 50) merged anatomically different areas; going much higher (e.g., 150+) got a bit noisy around ribs. 100 felt like a workable middle ground.
- Compactness: values far below 10 tended to snake along edges too aggressively; far above 10 produced blocky tiles. 10 looked balanced on my samples.
- LAB option: for these grayscale inputs, running in LAB didn’t noticeably change boundaries, so I kept it off to keep the pipeline minimal.
def compute_superpixels_from_gray(gray01_hw,
n_segments=SLIC_N_SEGMENTS,
compactness=SLIC_COMPACTNESS,
sigma=SLIC_SIGMA,
use_lab=USE_LAB):
"""
gray01_hw: float32 [H,W] in [0,1]
returns:
seg: [H,W] int labels (0..K-1)
rgb: [H,W,3] float image for visualization
"""
g = gray01_hw.astype(np.float32) # ensure float32
rgb = np.dstack([g, g, g]) # SLIC wants multi-channel input
img_for_slic = rgb2lab(rgb) if use_lab else rgb # optional LAB conversion
seg = slic(img_for_slic,
n_segments=int(n_segments),
compactness=float(compactness),
sigma=float(sigma),
start_label=0,
channel_axis=-1)
return seg.astype(np.int32), rgb # labels + the RGB used for viz
3.3 Binary masks per segment¶
For later experiments (e.g., region-wise occlusion), I need a mask for each superpixel. From the seg label map, I extract the unique labels and build a list of binary masks (1 = keep, 0 = occlude) for each region. This makes it easy to apply the exact same occlusion procedure as in Part 2, but now at the level of superpixels rather than fixed squares.
def masks_from_segments(seg):
"""
seg: [H,W] labels 0..K-1
returns:
labels: np.ndarray of unique labels
masks: list of uint8 masks (1=keep, 0=occlude) per label
"""
labels = np.unique(seg) # sorted unique segment ids
masks = [(seg == k).astype(np.uint8) for k in labels] # one binary mask per id
return labels, masks
3.4 Visualizing boundaries¶
To sanity-check the segmentation, I overlay SLIC boundaries on the original image. This helps see whether regions roughly follow anatomical structures (e.g., lung fields, ribs, heart border). The expectation isn’t perfect clinical boundaries—just reasonable adherence so that region-level tests aren’t arbitrary.
I run the same visualization on the selected test images from Part 2 (the ones used for occlusion). Keeping the sample consistent makes it easier to relate superpixel results to the earlier mask-based occlusions.
def show_superpixel_boundaries(path, n_segments=SLIC_N_SEGMENTS,
compactness=SLIC_COMPACTNESS, sigma=SLIC_SIGMA,
use_lab=USE_LAB):
g = cv2.imread(path, cv2.IMREAD_GRAYSCALE).astype(np.float32)/255.0 # [0,1] grayscale
seg, rgb = compute_superpixels_from_gray(g, n_segments, compactness, sigma, use_lab)
K = int(seg.max() + 1) # number of segments
fig, axs = plt.subplots(1, 2, figsize=(8.8, 3.0), constrained_layout=True)
axs[0].imshow(g, cmap='gray', vmin=0, vmax=1); axs[0].set_title("Original"); axs[0].axis('off')
overlay = mark_boundaries(rgb, seg, color=(1,0,0), mode='thick') # red boundaries on top
axs[1].imshow(overlay); axs[1].set_title(f"SLIC boundaries (K={K})"); axs[1].axis('off')
plt.show()
return seg # reuse in Part (iv)
# Show boundaries for each selected image (5–10 total)
SLIC_N_SEGMENTS, SLIC_COMPACTNESS, SLIC_SIGMA = 100, 10.0, 1.0 # as suggested baseline
for path, label, _ in picked[:8]:
seg = show_superpixel_boundaries(path, SLIC_N_SEGMENTS, SLIC_COMPACTNESS, SLIC_SIGMA, USE_LAB)
3.4.1 Observation¶
The number of segments varied (e.g., 89–98 instead of exactly 100), which is expected with SLIC. The boundaries generally follow the lung outlines and rib structures, though they are not perfect clinical regions. This seems good enough for the later occlusion experiments, since the goal is approximate region partitioning rather than exact anatomy.
3.4.2 Quick sanity check¶
As a small check, I confirm that the number of masks equals the number of labels (K) and that masks have plausible pixel counts. This is mainly to catch accidental issues (like empty segments) before using these masks in Part 4.
# Quick mask sanity for the first image
p0, y0, _ = picked[0] # one test sample (path, label, prob)
g0 = cv2.imread(p0, cv2.IMREAD_GRAYSCALE).astype(np.float32)/255.0 # [0,1] grayscale
seg0, _ = compute_superpixels_from_gray(g0) # run SLIC on this image
labels0, masks0 = masks_from_segments(seg0) # 1 binary mask per superpixel
print(f"K={len(labels0)} masks; first mask has {masks0[0].sum()} pixels (1s).")
K=89 masks; first mask has 586 pixels (1s).
4 Ordering Segments from Least Important¶
4.1 Per-image superpixels (reuse from Part 3)¶
For each selected test image, I reuse the SLIC settings from Part 3 to obtain a label map of ~100 superpixels. To avoid recomputing, a small dictionary caches the label maps keyed by file path. This keeps later scoring steps quick and makes the results reproducible across cells.
# Uses compute_superpixels_from_gray(...) from earlier.
seg_cache = {} # path -> seg labels (memoize per file)
def get_segments(path, n_segments=100, compactness=10.0, sigma=1.0, use_lab=False):
if path in seg_cache:
return seg_cache[path] # hit: reuse computed SLIC
g = cv2.imread(path, cv2.IMREAD_GRAYSCALE).astype(np.float32)/255.0 # [0,1] grayscale
seg, _ = compute_superpixels_from_gray(g, n_segments, compactness, sigma, use_lab)
seg_cache[path] = seg # cache for later calls
return seg
4.2 Drop scores per superpixel¶
The idea is the same as in Part 2, but now the occluder is one superpixel at a time.
For a given image and its label:
- compute the original probability for the true class,
- for each segment k, hide only that segment (fill with the neutral value from Part 2),
- run the model again and record the drop = p_true(orig) − p_true(occluded k).
A positive drop means the model became less confident for the true class when that segment was hidden → that segment likely helped the decision.
A negative drop means the model became more confident after hiding the segment → that region may be distracting or off-task.
For simplicity (and to match the assignment), I call occlusion_drop once per segment. Batching the masked images would be faster, but wasn’t necessary for this small number of segments.
# Per-spec implementation: call occlusion_drop once per superpixel.
def segment_drops_simple(path: str, label: int, seg: np.ndarray,
neutral="dataset") -> Tuple[np.ndarray, np.ndarray, float]:
"""
Returns:
drops[K]: Δp_true for each superpixel k
areas[K]: area fraction for each k
p_true_orig: original p_true for the image
"""
g = cv2.imread(path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.0 # [0,1] grayscale
H, W = g.shape
K = int(seg.max() + 1) # assume seg labels are 0..K-1
drops = np.zeros(K, np.float32)
areas = np.zeros(K, np.float32)
# baseline p_true with no occlusion (mask = all ones)
_, p0, _, _ = occlusion_drop(
g, np.ones((H, W), np.uint8), model_occ,
true_label=label, neutral=neutral, device=DEVICE
)
p_true_orig = p0
for k in range(K):
# mask that occludes only superpixel k (1=keep, 0=occlude k)
m = np.ones((H, W), np.uint8); m[seg == k] = 0
drop, _, _, _ = occlusion_drop(
g, m, model_occ,
true_label=label, neutral=neutral, device=DEVICE
)
drops[k] = drop # positive drop ⇒ k supports the true class
areas[k] = (seg == k).sum() / (H * W) # relative area of k
return drops, areas, p_true_orig
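For reference, the batching mentioned above could start like this: compose all K occluded variants once, then score them through the model in chunks. build_occluded_batch is a hypothetical helper sketched under the same [0,1] grayscale assumptions as segment_drops_simple:

```python
import numpy as np

def build_occluded_batch(g, seg, filler):
    """Compose all K single-segment occlusions at once.
    g: [H,W] float image in [0,1]; seg: [H,W] int labels 0..K-1;
    filler: scalar neutral value. Returns a [K,H,W] float32 stack where
    slice k has only superpixel k replaced by the filler."""
    K = int(seg.max() + 1)
    batch = np.repeat(g[None, ...], K, axis=0).astype(np.float32)
    for k in range(K):
        batch[k][seg == k] = filler      # occlude segment k only
    return batch

# The stack can then be scored in chunked forward passes, e.g.
#   torch.sigmoid(model(torch.from_numpy(batch[i:i+64, None]).to(DEVICE)))
# instead of K separate occlusion_drop calls.
```

With ~100 segments per image the simple loop is fast enough, so the spec-compliant version above was kept.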
4.3 Ranking and a compact table¶
After computing the per-segment drops, I sort them from least to most important (smallest → largest drop). The table shows the bottom few and top few segments with:
- Δp_true (signed drop),
- |Δp| (magnitude),
- Δp / 1% area (drop normalized by how much of the image the segment covers), and
- the area of the segment as a percentage of the image.
The area-normalized column is helpful for spotting tiny but influential regions versus large, mildly informative ones. I keep both the raw drop and the normalized value because each highlights different behaviors.
def rank_and_table_for_image(path, label, seg, drops, areas, top=5, bottom=5):
# Sort by signed Δp (ascending): more negative ⇒ stronger counter-evidence; more positive ⇒ supporting evidence.
order = np.argsort(drops)
K = len(drops)
# Select bottom-k and top-k segments under the signed metric.
head = order[:bottom]
tail = order[-top:] if K > top else order
def row(k):
# Normalize by segment size to report Δp per 1% image area; guard against near-zero areas.
per_1pct = drops[k] / max(areas[k] * 100.0, 1e-9)
return [
int(k), # segment id
f"{drops[k]:.3f}", # signed Δp_true
f"{abs(drops[k]):.3f}", # |Δp|
f"{per_1pct:.3f}", # Δp per 1% area
f"{areas[k] * 100:.1f}%" # area (%)
]
# Emit a compact summary table for quick inspection of extremes.
rows = [row(k) for k in list(head) + list(tail)]
print(f"{os.path.basename(path)} | y={label} | K={K}")
print(tabulate(rows,
headers=["seg", "Δp_true", "|Δp|", "Δp / 1% area", "area"],
tablefmt="github"))
# Return full ranking (ascending by signed Δp) for downstream use.
return order
4.4 Running the ranking on the selected images¶
I run the procedure on the same eight test images used in Part 2. For each image, I:
- fetch (or compute) its SLIC labels,
- compute the drop for every superpixel using the dataset-mean filler,
- print the ranked summary, and
- store the outputs (drops, areas, order, etc.) for Part 5 visualizations.
superpixel_results = {} # path -> dict(seg, drops, areas, order, p_true_orig, label)
for path, label, prob in picked[:8]:
# SLIC segmentation (Part iii helper). Fixed hyperparams for reproducibility/consistency across images.
# use_lab=False since inputs are grayscale; sigma smooths small artifacts before SLIC.
seg = get_segments(path, n_segments=100, compactness=10.0, sigma=1.0, use_lab=False)
# Spec-compliant scoring: one occlusion_drop call per segment (O(K) passes).
drops, areas, p_true_orig = segment_drops_simple(path, label, seg, neutral="dataset")
# Tabulate extremes and recover the global ranking (signed Δp: negative→counter, positive→support).
order = rank_and_table_for_image(path, label, seg, drops, areas, top=5, bottom=5)
# Cache for Part (v): retain both raw effects and the ranking for later visualization/analysis.
superpixel_results[path] = {
"label": label, # ground-truth class id used for p_true
"seg": seg, # SLIC label map (0..K-1)
"drops": drops, # per-seg Δp_true (occlude-k vs baseline)
"areas": areas, # per-seg area fraction in [0,1]
"order": order, # argsort indices (ascending by signed Δp)
"p_true_orig": p_true_orig # baseline p_true with no occlusion
}
person294_bacteria_1386.png | y=1 | K=89
| seg | Δp_true | |Δp| | Δp / 1% area | area |
|-------|-----------|--------|----------------|--------|
| 16 | -0.011 | 0.011 | -0.01 | 1.1% |
| 87 | -0.009 | 0.009 | -0.01 | 0.9% |
| 79 | -0.008 | 0.008 | -0.012 | 0.7% |
| 29 | -0.006 | 0.006 | -0.004 | 1.5% |
| 68 | -0.005 | 0.005 | -0.004 | 1.1% |
| 19 | 0.014 | 0.014 | 0.017 | 0.8% |
| 12 | 0.019 | 0.019 | 0.016 | 1.2% |
| 10 | 0.021 | 0.021 | 0.012 | 1.8% |
| 14 | 0.023 | 0.023 | 0.029 | 0.8% |
| 5 | 0.024 | 0.024 | 0.018 | 1.3% |
person1352_bacteria_3444.png | y=1 | K=97
| seg | Δp_true | |Δp| | Δp / 1% area | area |
|-------|-----------|--------|----------------|--------|
| 78 | -0 | 0 | -0 | 1.0% |
| 65 | -0 | 0 | -0 | 2.0% |
| 62 | -0 | 0 | -0 | 1.2% |
| 54 | -0 | 0 | -0 | 1.0% |
| 64 | -0 | 0 | -0 | 1.0% |
| 27 | 0.001 | 0.001 | 0.001 | 1.2% |
| 25 | 0.001 | 0.001 | 0.001 | 1.3% |
| 12 | 0.002 | 0.002 | 0.002 | 1.2% |
| 0 | 0.002 | 0.002 | 0.003 | 0.6% |
| 15 | 0.003 | 0.003 | 0.002 | 1.3% |
person294_virus_611.png | y=1 | K=98
| seg | Δp_true | |Δp| | Δp / 1% area | area |
|-------|-----------|--------|----------------|--------|
| 68 | -0.001 | 0.001 | -0.001 | 1.1% |
| 69 | -0.001 | 0.001 | -0.001 | 0.8% |
| 27 | -0.001 | 0.001 | -0 | 1.3% |
| 79 | -0 | 0 | -0 | 1.2% |
| 26 | -0 | 0 | -0 | 1.1% |
| 9 | 0.001 | 0.001 | 0.001 | 0.9% |
| 4 | 0.001 | 0.001 | 0.001 | 1.4% |
| 24 | 0.001 | 0.001 | 0.001 | 1.4% |
| 43 | 0.001 | 0.001 | 0.001 | 2.0% |
| 14 | 0.002 | 0.002 | 0.003 | 0.8% |
person1063_virus_1765.png | y=1 | K=94
| seg | Δp_true | |Δp| | Δp / 1% area | area |
|-------|-----------|--------|----------------|--------|
| 1 | -0.004 | 0.004 | -0.005 | 0.7% |
| 66 | -0.003 | 0.003 | -0.003 | 1.1% |
| 73 | -0.002 | 0.002 | -0.003 | 0.9% |
| 38 | -0.002 | 0.002 | -0.001 | 1.3% |
| 16 | -0.002 | 0.002 | -0.002 | 0.7% |
| 6 | 0.007 | 0.007 | 0.006 | 1.2% |
| 37 | 0.008 | 0.008 | 0.008 | 1.0% |
| 45 | 0.011 | 0.011 | 0.005 | 2.3% |
| 52 | 0.013 | 0.013 | 0.009 | 1.4% |
| 21 | 0.021 | 0.021 | 0.011 | 2.0% |
person554_virus_1094.png | y=1 | K=97
| seg | Δp_true | |Δp| | Δp / 1% area | area |
|-------|-----------|--------|----------------|--------|
| 46 | -0.049 | 0.049 | -0.038 | 1.3% |
| 58 | -0.042 | 0.042 | -0.03 | 1.4% |
| 6 | -0.038 | 0.038 | -0.042 | 0.9% |
| 75 | -0.035 | 0.035 | -0.048 | 0.7% |
| 67 | -0.035 | 0.035 | -0.024 | 1.4% |
| 8 | 0.059 | 0.059 | 0.05 | 1.2% |
| 9 | 0.06 | 0.06 | 0.079 | 0.8% |
| 32 | 0.073 | 0.073 | 0.042 | 1.7% |
| 37 | 0.079 | 0.079 | 0.069 | 1.1% |
| 47 | 0.08 | 0.08 | 0.06 | 1.3% |
person520_virus_1039.png | y=1 | K=99
| seg | Δp_true | |Δp| | Δp / 1% area | area |
|-------|-----------|--------|----------------|--------|
| 62 | -0.034 | 0.034 | -0.038 | 0.9% |
| 23 | -0.033 | 0.033 | -0.032 | 1.1% |
| 75 | -0.032 | 0.032 | -0.029 | 1.1% |
| 47 | -0.031 | 0.031 | -0.031 | 1.0% |
| 11 | -0.029 | 0.029 | -0.016 | 1.8% |
| 92 | 0.022 | 0.022 | 0.017 | 1.3% |
| 44 | 0.031 | 0.031 | 0.038 | 0.8% |
| 80 | 0.033 | 0.033 | 0.035 | 1.0% |
| 91 | 0.04 | 0.04 | 0.027 | 1.5% |
| 79 | 0.053 | 0.053 | 0.035 | 1.5% |
person1273_virus_2191.png | y=1 | K=94
| seg | Δp_true | |Δp| | Δp / 1% area | area |
|-------|-----------|--------|----------------|--------|
| 79 | -0.033 | 0.033 | -0.032 | 1.0% |
| 31 | -0.027 | 0.027 | -0.029 | 0.9% |
| 15 | -0.02 | 0.02 | -0.012 | 1.7% |
| 88 | -0.02 | 0.02 | -0.015 | 1.3% |
| 62 | -0.019 | 0.019 | -0.021 | 0.9% |
| 58 | 0.031 | 0.031 | 0.027 | 1.1% |
| 63 | 0.033 | 0.033 | 0.033 | 1.0% |
| 35 | 0.036 | 0.036 | 0.022 | 1.6% |
| 55 | 0.042 | 0.042 | 0.053 | 0.8% |
| 46 | 0.057 | 0.057 | 0.037 | 1.5% |
person1588_virus_2762.png | y=1 | K=98
| seg | Δp_true | |Δp| | Δp / 1% area | area |
|-------|-----------|--------|----------------|--------|
| 85 | -0.096 | 0.096 | -0.157 | 0.6% |
| 79 | -0.074 | 0.074 | -0.071 | 1.0% |
| 35 | -0.052 | 0.052 | -0.021 | 2.5% |
| 22 | -0.045 | 0.045 | -0.07 | 0.6% |
| 97 | -0.041 | 0.041 | -0.05 | 0.8% |
| 48 | 0.061 | 0.061 | 0.063 | 1.0% |
| 59 | 0.061 | 0.061 | 0.053 | 1.1% |
| 76 | 0.064 | 0.064 | 0.053 | 1.2% |
| 68 | 0.065 | 0.065 | 0.062 | 1.1% |
| 84 | 0.084 | 0.084 | 0.062 | 1.4% |
4.4.1 Quick observations across confidence levels¶
Looking across both very confident cases (Part 2 selections ≥0.95) and some with more moderate probabilities (~0.8–0.9), a few trends emerge:
- Many near-zero drops: a large share of superpixels hardly affect the prediction, which is expected. These are often background or smooth tissue areas, where hiding them doesn’t change the model’s confidence.
- Positive vs. negative drops:
- In high-confidence images, most influential segments are inside the lung fields, and drops are small but consistently positive: occluding them reduces confidence, which makes sense.
- In lower-confidence cases, I observed both stronger positive and stronger negative drops. Some superpixels pushed the model in the wrong direction (negative drop: occluding them improved the prediction). These were often at lung borders or noisy regions.
- Magnitude differences: drops in moderate-confidence images can be larger in absolute value (e.g., −0.07 or +0.08), showing the model is more sensitive to local perturbations when it is less certain overall.
- Area normalization: normalizing by area occasionally highlights small superpixels with disproportionately high influence, while larger ones tend to have diluted effects. It’s a reminder that small but distinctive regions (e.g., a focal opacity) can matter more than broad homogeneous ones.
Overall, the ranking seems to capture a mix of expected behavior (lung regions matter most) and some surprising sensitivities (border or background segments that reduce noise when occluded). Especially in less confident predictions, the model’s reliance on certain regions appears more unstable.
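The "many near-zero drops" observation can be quantified rather than eyeballed. A minimal sketch, assuming nothing beyond NumPy; the 0.005 threshold and the synthetic drops array below are illustrative stand-ins for a real superpixel_results entry:

```python
import numpy as np

def drop_summary(drops: np.ndarray, near_zero: float = 0.005) -> dict:
    """Fraction of segments with negligible |Δp|, plus the strongest effect in each direction."""
    return {
        "frac_near_zero": float((np.abs(drops) < near_zero).mean()),
        "max_support": float(drops.max()),   # strongest positive (supporting) drop
        "max_counter": float(drops.min()),   # strongest negative (counter-evidence) drop
    }

# Synthetic stand-in for one image's per-segment drops: mostly noise, a few big effects.
rng = np.random.default_rng(0)
demo = rng.normal(0.0, 0.003, size=100).astype(np.float32)
demo[:3] = [0.08, -0.07, 0.06]  # a handful of influential segments
stats = drop_summary(demo)
```

On real results, applying this per image would make the high- vs moderate-confidence contrast above concrete (e.g., comparing `frac_near_zero` across the eight images).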
5 Visualization of Results¶
After ranking the superpixels by importance, the final step is to visualize these results.
The goal here is to connect the quantitative drop values back to the original X-rays, so we can see where the model is focusing when making decisions.
5.1 Heatmap builder¶
To make the importance of each superpixel interpretable, I map the drop values (Δp_true) back to the pixel grid.
This produces a heatmap where higher drops are shown in stronger red tones.
By default, negative values are clipped to zero so that only positively contributing regions are emphasized.
def superpixel_heatmap(seg: np.ndarray, drops: np.ndarray, clip_neg: bool = True):
"""
seg: [H, W] int labels in 0..K-1 (assumed dense; no gaps)
drops: [K] array of Δp_true per superpixel (same K as seg.max()+1)
clip_neg: if True, negative contributions are zeroed (focus on supportive evidence)
returns: heat[H, W] float, per-pixel value inherited from its segment
"""
# Copy avoids mutating caller's drops; np.maximum keeps dtype (float32/64) as-is.
vals = drops.copy()
if clip_neg:
vals = np.maximum(vals, 0.0) # clamp to [0, ∞)
# Vectorized gather: each pixel takes the value of its segment id.
# Assumes seg.dtype is integer and all labels < len(vals).
heat = vals[seg]
return heat
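The vectorized gather `vals[seg]` is the whole trick here: indexing a length-K value array with an [H, W] integer label map broadcasts each segment's value to all of its pixels. A toy check with made-up values:

```python
import numpy as np

seg = np.array([[0, 0, 1],
                [2, 2, 1]])                    # 2x3 label map with K=3 segments
drops = np.array([0.10, -0.05, 0.30])          # per-segment Δp_true (hypothetical)

vals = np.maximum(drops, 0.0)                  # clip_neg=True behavior: zero out negatives
heat = vals[seg]                               # each pixel inherits its segment's value
# heat == [[0.1, 0.1, 0.0],
#          [0.3, 0.3, 0.0]]
```

Segment 1's negative drop is clipped to 0, so its pixels vanish from the support-only heatmap.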
5.2 Cumulative drop curve¶
As an optional check, I compute a cumulative occlusion curve: progressively removing superpixels from least to most important and tracking the resulting probability drop.
While not strictly additive, this gives a sense of how model confidence depends on different fractions of the image.
Some non-intuitive fluctuations are expected due to interactions between regions.
@torch.no_grad()
def cumulative_drop_curve(path: str, label: int, seg: np.ndarray, order: np.ndarray,
neutral: str = "dataset"):
"""
order: indices sorted least→most important (from Part iv)
Returns:
x: fraction of superpixels occluded (i/K)
y: cumulative Δp_true after occluding the first i segments in 'order'
"""
g = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
if g is None:
    raise FileNotFoundError(f"could not read image: {path}")
g = g.astype(np.float32) / 255.0 # [0,1] grayscale
H, W = g.shape
K = int(seg.max() + 1)
# y_vals[i-1] stores Δp_true after occluding the first i superpixels in 'order'
y_vals = []
for i in range(1, K + 1):
# Occlude the i least-important segments so far (1=keep, 0=occlude).
m = np.ones((H, W), np.uint8)
m[np.isin(seg, order[:i])] = 0 # np.isin does the membership test directly
# Single forward pass per prefix; records cumulative effect at step i.
drop, _, _, _ = occlusion_drop(
g, m, model_occ, true_label=label, neutral=neutral, device=DEVICE
)
y_vals.append(drop)
# x advances uniformly by superpixel count; area-weighted x would require segment areas.
x = np.arange(1, K + 1) / K
return x, np.array(y_vals, dtype=np.float32)
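As the comment above notes, the x-axis advances by superpixel count, not by occluded area. A minimal sketch of the area-weighted alternative, using synthetic `areas` and `order` arrays in place of the Part 4 outputs:

```python
import numpy as np

# Synthetic stand-ins: 4 segments whose area fractions sum to 1,
# ranked least→most important (segment 3 least, segment 0 most).
areas = np.array([0.40, 0.30, 0.20, 0.10])
order = np.array([3, 2, 1, 0])

K = len(order)
x_count = np.arange(1, K + 1) / K       # uniform steps per segment (what the code uses)
x_area = np.cumsum(areas[order])        # cumulative image fraction actually occluded
# x_area == [0.1, 0.3, 0.6, 1.0] — large segments move the curve further right per step
```

Plotting the same y-values against `x_area` instead of `x_count` would show how much of the image, rather than how many segments, must be hidden before confidence collapses.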
5.3 Per-image visualization¶
For each selected test image, I combine the different elements into a row:
- Original image for reference.
- Superpixel boundaries to show the segmentation.
- Importance heatmap overlaid on the X-ray.
- Bar chart of the top-k most important superpixels.
- Optionally, the cumulative curve for global context.
This gives both a localized and aggregated view of importance.
plt.rcParams["figure.dpi"] = 110
def show_part5_for_image(path: str, res: dict, topk: int = 20, clip_neg: bool = True, show_cumulative: bool = False):
"""
res: superpixel_results[path] from Part (iv)
{ 'label', 'seg', 'drops', 'areas', 'order', 'p_true_orig' }
"""
label = res["label"]; seg = res["seg"]; drops = res["drops"]; areas = res["areas"]; order = res["order"]
# Base image in [0,1] grayscale; RGB triplet used only for boundary overlay.
g = cv2.imread(path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.0
rgb = np.dstack([g, g, g])
# Heatmap derived from per-segment Δp; optionally zero out negatives (support-only view).
heat = superpixel_heatmap(seg, drops, clip_neg=clip_neg)
pos_vals = heat[heat > 0]
vmax = float(np.percentile(pos_vals, 99)) if pos_vals.size else 1.0 # robust upper bound for colormap
# Top-k selection by raw Δp_true (largest positive effects); preserve descending order for bars.
idx_sorted = np.argsort(drops) # least→most (signed)
top_idx = idx_sorted[-topk:] if len(idx_sorted) > topk else idx_sorted
top_idx = top_idx[np.argsort(-drops[top_idx])] # reorder to descending Δp_true
# Layout: add cumulative panel if requested.
ncols = 5 if show_cumulative else 4
fig, axs = plt.subplots(
1, ncols,
figsize=(13.5 if show_cumulative else 11.5, 3.0),
constrained_layout=True
)
fig.suptitle(f"{os.path.basename(path)} | y={label} | K={int(seg.max()+1)}", fontsize=10)
# (1) Original
axs[0].imshow(g, cmap='gray', vmin=0, vmax=1)
axs[0].set_title("Original"); axs[0].axis('off')
# (2) Superpixel boundaries (thick red for visibility on grayscale)
axs[1].imshow(mark_boundaries(rgb, seg, color=(1, 0, 0), mode='thick'))
axs[1].set_title("Superpixel boundaries"); axs[1].axis('off')
# (3) Heatmap overlay — Δp_true shown in Reds; alpha blends with the base image.
axs[2].imshow(g, cmap='gray', vmin=0, vmax=1)
im = axs[2].imshow(heat, cmap='Reds', alpha=0.45, vmin=0, vmax=vmax)
axs[2].set_title("Importance heatmap (Δp_true)"); axs[2].axis('off')
cbar = plt.colorbar(im, ax=axs[2], fraction=0.046, pad=0.04)
cbar.set_label("Δp_true", fontsize=8)
# (4) Bar chart — top-k segments by Δp_true
axb = axs[3]
axb.bar(range(len(top_idx)), drops[top_idx])
axb.set_title(f"Top-{len(top_idx)} superpixels by Δp_true")
axb.set_xlabel("segment id (ranked)"); axb.set_ylabel("Δp_true")
axb.set_xticks(range(len(top_idx)))
axb.set_xticklabels([int(i) for i in top_idx], rotation=90, fontsize=7)
axb.grid(True, alpha=0.3)
# (5) Optional cumulative curve — least→most important per signed Δp
if show_cumulative:
x, y = cumulative_drop_curve(path, label, seg, order, neutral="dataset")
axc = axs[4]
axc.plot(x, y, marker='o', linewidth=1)
# Include original p_true from cache for quick reference.
p_true = res.get("p_true_orig", None)
title = f"{os.path.basename(path)} | y={label} | K={int(seg.max()+1)}"
if p_true is not None:
title += f" | p_true(orig)={p_true:.3f}"
fig.suptitle(title, fontsize=10)
axc.set_xlabel("Fraction of segments occluded")
axc.set_ylabel("Δp_true (cumulative)")
axc.grid(True, alpha=0.3)
# --- SAVE to {PROJECT_ROOT}/figures/final_visualizations_<image-stem>.png ---
fig_dir = Path(PROJ_ROOT) / "figures"
fig_dir.mkdir(parents=True, exist_ok=True)
out_path = fig_dir / f"final_visualizations_{Path(path).stem}.png"
fig.savefig(out_path, dpi=150, bbox_inches="tight")
plt.show()
5.4 Displaying final results¶
I apply the visualization function to the eight selected test images.
The outputs show varying patterns: in high-confidence predictions, importance often lies on outer or border regions (suggesting shortcut reliance), whereas in lower-confidence cases, the focus shifts more toward the lungs and inner anatomy.
These results highlight both the strengths and limits of the preprocessing pipeline.
# Requires: superpixel_results (from Part 4) populated for these paths.
for path, _, _ in picked[:8]:
# Per-image visualization row: Original | Boundaries | Heatmap | Bars | Cumulative curve
show_part5_for_image(
path,
superpixel_results[path],
topk=20, # show top-20 segments by Δp_true in the bar chart
clip_neg=True, # heatmap focuses on supportive (positive Δp) regions
show_cumulative=True # include least→most cumulative drop curve
)
5.5 Conclusions¶
Overall, the visualizations suggest that while the model sometimes relies on background or border cues, especially in highly confident cases, the preprocessing pipeline has partially reduced this effect.
In the lower-confidence cases, the model appears to attend more to clinically meaningful lung regions, which is a positive sign.
The cumulative occlusion curves also confirm that predictive signal is unevenly distributed across segments, though their non-monotonic patterns show that occlusion is not perfectly additive.
Together, these experiments illustrate the value of interpretability methods: they reveal both potential vulnerabilities (shortcut learning) and encouraging evidence that, under some conditions, the model does focus on relevant anatomy.
Further remarks¶
The validation and test performance of the model was unusually high, and these visualizations support the concern that not all of that accuracy comes from lung-related features. Further investigation could include:
- exploring stricter preprocessing or artifact removal to limit shortcut reliance,
- comparing with other interpretability methods (e.g., Grad-CAM, SHAP) for complementary insights,
- and testing whether these patterns persist on larger or more balanced datasets.
These steps could help clarify whether the strong performance reflects genuine learning of disease-relevant structures or partial dependence on spurious cues.